Make lidarpcs batch query. #376

Open: wants to merge 3 commits into master
111 changes: 111 additions & 0 deletions nuplan/database/nuplan_db/nuplan_scenario_queries.py
@@ -24,6 +24,7 @@
from nuplan.database.nuplan_db.query_session import execute_many, execute_one
from nuplan.database.nuplan_db.sensor_data_table_row import SensorDataTableRow
from nuplan.database.utils.label.utils import local2agent_type, raw_mapping
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling


def _parse_tracked_object_row(row: sqlite3.Row) -> TrackedObject:
@@ -447,6 +448,66 @@ def get_sampled_lidarpcs_from_db(
yield LidarPc.from_db_row(row)


def get_sampled_lidarpcs_from_db_batch(
    log_file: str,
    initial_token: str,
    sensor_source: SensorDataSource,
    sample_indexes: List[int],
    future: bool,
) -> List[LidarPc]:
    """
    Fetch a batch of sampled LidarPcs from the DB in a single query, anchored at the provided token.
    :param log_file: The log file to query.
    :param initial_token: The token of the anchor lidar_pc from which samples are offset.
    :param sensor_source: The sensor source (table and channel) to query.
    :param sample_indexes: The 0-indexed sample offsets, relative to the anchor, to return.
    :param future: If True, sample forward in time from the anchor; otherwise, backward.
    :return: The matching LidarPcs, ordered by timestamp in ascending order.
    """
    if not sample_indexes:
        return []

sensor_token = get_sensor_token(log_file, sensor_source.sensor_table, sensor_source.channel)

order_direction = "ASC" if future else "DESC"
order_cmp = ">=" if future else "<="

query = f"""
WITH initial_lidarpc AS
(
SELECT token, timestamp
FROM lidar_pc
WHERE token = ?
),
ordered AS
(
SELECT lp.token,
lp.next_token,
lp.prev_token,
lp.ego_pose_token,
lp.lidar_token,
lp.scene_token,
lp.filename,
lp.timestamp,
ROW_NUMBER() OVER (ORDER BY lp.timestamp {order_direction}) AS row_num
FROM lidar_pc AS lp
CROSS JOIN initial_lidarpc AS il
WHERE lp.timestamp {order_cmp} il.timestamp
AND lp.lidar_token = ?
)
SELECT token,
next_token,
prev_token,
ego_pose_token,
lidar_token,
scene_token,
filename,
timestamp
FROM ordered

    -- ROW_NUMBER() starts at 1, but consumers expect sample_indexes to be 0-indexed.
WHERE (row_num - 1) IN ({('?,'*len(sample_indexes))[:-1]})

ORDER BY timestamp ASC;
"""

args = [bytearray.fromhex(initial_token), bytearray.fromhex(sensor_token)] + sample_indexes # type: ignore
rows = execute_many(query, args, log_file)
return [LidarPc.from_db_row(row) for row in rows]
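
# Note (reviewer sketch, not part of this diff): the IN clause above is built by
# repeating "?," once per sample index and trimming the trailing comma, e.g.
#   ('?,' * 3)[:-1]  ->  "?,?,?"
# The early return for an empty sample_indexes list matters, because "IN ()" is
# a syntax error in SQLite.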



def get_sampled_ego_states_from_db(
log_file: str,
initial_token: str,
@@ -778,6 +839,56 @@ def get_future_waypoints_for_agents_from_db(
yield (row["track_token"].hex(), Waypoint(TimePoint(row["timestamp"]), oriented_box, velocity))


def get_future_waypoints_for_agents_from_db_optimized(
log_file: str, track_tokens: List[str], start_timestamp: int, future_trajectory_sampling: TrajectorySampling
) -> Generator[Tuple[str, Waypoint], None, None]:
"""
Obtain the future waypoints for the selected agents from the DB in the provided time window,
taking into account the sampling interval for future waypoints.

:param log_file: The log file to query.
:param track_tokens: The track_tokens for which to query.
:param start_timestamp: The starting timestamp for which to query.
:param future_trajectory_sampling: The trajectory sampling strategy.
:return: A generator of tuples of (track_token, Waypoint), sorted by track_token, then by timestamp in ascending order.
"""
interval_microseconds = int(1e6 * future_trajectory_sampling.interval_length)
end_timestamp = start_timestamp + int(1e6 * future_trajectory_sampling.time_horizon)

# Adjust the query to return waypoints based on the specified interval.
# The following SQL is an example and might need adjustments based on the actual schema.
query = f"""
WITH RECURSIVE sampled_timestamps(ts) AS (
SELECT ? UNION ALL
SELECT ts + ? FROM sampled_timestamps
WHERE ts + ? <= ?
)
SELECT
lb.x, lb.y, lb.z, lb.yaw, lb.width, lb.length, lb.height, lb.vx, lb.vy, lb.track_token, lp.timestamp
FROM
lidar_box AS lb
INNER JOIN
lidar_pc AS lp ON lp.token = lb.lidar_pc_token
INNER JOIN
sampled_timestamps st ON lp.timestamp >= st.ts AND lp.timestamp < st.ts + ?
WHERE
lp.timestamp >= ? AND lp.timestamp <= ? AND lb.track_token IN ({('?,' * len(track_tokens))[:-1]})
ORDER BY
lb.track_token ASC, lp.timestamp ASC;
"""
args = [start_timestamp, interval_microseconds, interval_microseconds, end_timestamp, interval_microseconds, start_timestamp, end_timestamp] + [bytearray.fromhex(t) for t in track_tokens]

for row in execute_many(query, args, log_file):
        # Parse the row data here and build the Waypoint object.
pose = StateSE2(row["x"], row["y"], row["yaw"])
oriented_box = OrientedBox(pose, width=row["width"], length=row["length"], height=row["height"])
velocity = StateVector2D(row["vx"], row["vy"])
waypoint = Waypoint(TimePoint(row["timestamp"]), oriented_box, velocity)

        # Yield the (track_token, Waypoint) pair.
yield (row["track_token"].hex(), waypoint)
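
# Note (reviewer sketch, not part of this diff): the join above buckets each
# lidar_pc into the half-open window [st.ts, st.ts + interval). If more than one
# lidar_pc falls into a bucket, every one of them is yielded, so downstream
# interpolation should tolerate slightly denser-than-requested samples.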


def get_scenarios_from_db(
log_file: str,
filter_tokens: Optional[List[str]],
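For reference, a minimal usage sketch of the new batch query. The log file path and anchor token below are placeholders, not values from this PR:

from nuplan.database.nuplan_db.nuplan_db_utils import get_lidarpc_sensor_data
from nuplan.database.nuplan_db.nuplan_scenario_queries import get_sampled_lidarpcs_from_db_batch

log_file = "/data/nuplan/logs/example.db"  # Placeholder path.
anchor_token = "0123456789abcdef0123456789abcdef"  # Placeholder lidar_pc token (hex).

# Fetch the lidar_pcs at offsets 0, 2, and 4 steps after the anchor in one query.
lidar_pcs = get_sampled_lidarpcs_from_db_batch(
    log_file, anchor_token, get_lidarpc_sensor_data(), sample_indexes=[0, 2, 4], future=True
)
print([pc.timestamp for pc in lidar_pcs])  # Timestamps come back in ascending order.
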
67 changes: 46 additions & 21 deletions nuplan/database/nuplan_db/query_session.py
@@ -1,57 +1,82 @@
import sqlite3
from collections import OrderedDict
from typing import Any, Generator, Optional

memory_dbs: "OrderedDict[str, sqlite3.Connection]" = OrderedDict()
MAX_CACHE_SIZE = 5  # Maximum number of cached in-memory connections allowed.

def get_or_copy_db_to_memory(db_file: str) -> sqlite3.Connection:
"""
Runs a query with the provided arguments on a specified Sqlite DB file.
This query can return any number of rows.
Get an existing in-memory database connection or copy the SQLite DB file to an in-memory database if not exists.
Manages cache size to not exceed MAX_CACHE_SIZE by removing the least recently used (LRU) connection.
:param db_file: The DB file to check or copy to memory.
:return: A connection to the in-memory database.
"""
    # If already cached, move the entry to the end of the dict to mark it as most recently used.
    if db_file in memory_dbs:
        memory_dbs.move_to_end(db_file)
        return memory_dbs[db_file]

    # If the cache has reached its maximum size, evict the oldest entry.
    if len(memory_dbs) >= MAX_CACHE_SIZE:
        oldest_db_file, oldest_conn = memory_dbs.popitem(last=False)  # Remove the least recently used item.
        oldest_conn.close()
        print(f"Closed and removed the oldest DB from cache: {oldest_db_file}")

    # Create a new in-memory database connection and populate it from the on-disk DB.
disk_connection = sqlite3.connect(db_file)
mem_connection = sqlite3.connect(':memory:')
disk_connection.backup(mem_connection) # Requires Python 3.7+
disk_connection.close()

    # Add to the cache and return.
memory_dbs[db_file] = mem_connection
return mem_connection


def execute_many(query_text: str, query_parameters: Any, db_file: str, use_mem: bool = True) -> Generator[sqlite3.Row, None, None]:
"""
Runs a query on a specified Sqlite DB file, preferably using an in-memory copy for improved speed.
:param query_text: The query to run.
:param query_parameters: The parameters to provide to the query.
:param db_file: The DB file on which to run the query.
:param db_file: The DB file to use, copying to memory if not already done.
:return: A generator of rows emitted from the query.
"""
    if use_mem:
        connection = get_or_copy_db_to_memory(db_file)
    else:
        connection = sqlite3.connect(db_file)

connection.row_factory = sqlite3.Row
cursor = connection.cursor()

try:
cursor.execute(query_text, query_parameters)

for row in cursor:
yield row
    finally:
        cursor.close()
        if not use_mem:
            # Only close throwaway disk connections; cached in-memory connections are reused.
            connection.close()

def execute_one(query_text: str, query_parameters: Any, db_file: str) -> Optional[sqlite3.Row]:
"""
Runs a query with the provided arguments on a specified Sqlite DB file.
Validates that the query returns at most one row.
Runs a query on a specified Sqlite DB file, preferably using an in-memory copy for improved speed.
:param query_text: The query to run.
:param query_parameters: The parameters to provide to the query.
:param db_file: The DB file on which to run the query.
:param db_file: The DB file to use, copying to memory if not already done.
:return: The returned row, if it exists. None otherwise.
"""
    connection = get_or_copy_db_to_memory(db_file)
connection.row_factory = sqlite3.Row
cursor = connection.cursor()

try:
cursor.execute(query_text, query_parameters)

result: Optional[sqlite3.Row] = cursor.fetchone()

# Check for more rows. If more exist, throw an error.
if result is not None and cursor.fetchone() is not None:
raise RuntimeError("execute_one query returned multiple rows.")

return result
finally:
cursor.close()
        # Do not close the connection here; the cached in-memory connection is reused across queries.
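
The cache in this file is a standard OrderedDict-based LRU. A self-contained sketch of the same pattern, with a deliberately tiny capacity to show eviction (independent of nuPlan; names are illustrative):

import sqlite3
from collections import OrderedDict

cache: "OrderedDict[str, sqlite3.Connection]" = OrderedDict()
CAPACITY = 2  # Deliberately small to demonstrate eviction.

def get_connection(db_file: str) -> sqlite3.Connection:
    if db_file in cache:
        cache.move_to_end(db_file)  # Mark as most recently used.
        return cache[db_file]
    if len(cache) >= CAPACITY:
        _, evicted = cache.popitem(last=False)  # Evict the least recently used entry.
        evicted.close()
    disk = sqlite3.connect(db_file)
    mem = sqlite3.connect(":memory:")
    disk.backup(mem)  # Copy the whole DB into memory (Python 3.7+).
    disk.close()
    cache[db_file] = mem
    return mem

Two caveats worth flagging in review: sqlite3 connections cannot be shared across threads by default (check_same_thread=True), and backup() copies the entire file into RAM, so peak memory grows with MAX_CACHE_SIZE times the DB size.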
34 changes: 30 additions & 4 deletions nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario.py
@@ -3,6 +3,7 @@
import os
import time
from functools import cached_property
from pathlib import Path
from typing import Any, Generator, List, Optional, Set, Tuple, Type, cast

from nuplan.common.actor_state.ego_state import EgoState
@@ -24,6 +25,7 @@
get_roadblock_ids_for_lidarpc_token_from_db,
get_sampled_ego_states_from_db,
get_sampled_lidarpcs_from_db,
get_sampled_lidarpcs_from_db_batch,
get_sensor_data_from_sensor_data_tokens_from_db,
get_sensor_data_token_timestamp_from_db,
get_sensor_transform_matrix_for_sensor_data_token_from_db,
@@ -361,11 +363,22 @@ def get_future_tracked_objects(
time_horizon: float,
num_samples: Optional[int] = None,
future_trajectory_sampling: Optional[TrajectorySampling] = None,
    ) -> List[DetectionsTracks]:
        """Inherited, see superclass."""
        start_time = time.time()
        lidar_pcs = self._find_matching_lidar_pcs_batch(iteration, num_samples, time_horizon, True)
        mid_time = time.time()
        print(f'_find_matching_lidar_pcs_batch took {(mid_time - start_time) * 1000} ms')
        detections_tracks = [
            DetectionsTracks(extract_tracked_objects(lidar_pc.token, self._log_file, future_trajectory_sampling))
            for lidar_pc in lidar_pcs
        ]
        end_time = time.time()
        print(f'Building all DetectionsTracks objects took {(end_time - mid_time) * 1000} ms')
        print(f'Total function execution took {(end_time - start_time) * 1000} ms')
        return detections_tracks


def get_past_sensors(
self,
@@ -446,6 +459,19 @@ def _find_matching_lidar_pcs(
self._log_file, self._lidarpc_tokens[iteration], get_lidarpc_sensor_data(), indices, look_into_future
),
)

    def _find_matching_lidar_pcs_batch(
        self, iteration: int, num_samples: Optional[int], time_horizon: float, look_into_future: bool
    ) -> List[LidarPc]:
        """
        Find the lidar_pcs matching the provided parameters using a single batch query.
        :param iteration: The iteration at which to anchor the search.
        :param num_samples: The number of samples to fetch; derived from time_horizon if None.
        :param time_horizon: The time horizon, in seconds, over which to sample.
        :param look_into_future: If True, sample forward in time; otherwise, backward.
        :return: The matching LidarPcs.
        """
        num_samples = num_samples if num_samples else int(time_horizon / self.database_interval)
        indices = sample_indices_with_time_horizon(num_samples, time_horizon, self._database_row_interval)

        # Replace the per-row generator lookup with a single batch query.
        lidarpcs = get_sampled_lidarpcs_from_db_batch(
            self._log_file, self._lidarpc_tokens[iteration], get_lidarpc_sensor_data(), indices, look_into_future
        )
        return list(lidarpcs)  # Ensure a list is returned.


def _extract_expert_trajectory(self, max_future_seconds: int = 60) -> Generator[EgoState, None, None]:
"""
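A rough harness for sanity-checking the speedup that the timing prints above measure. This is a sketch: "scenario" is assumed to be an already-constructed NuPlanScenario, and the iteration/horizon arguments are illustrative:

import time

def time_call(fn, *args, n: int = 5) -> float:
    """Return the average wall-clock milliseconds over n calls."""
    start = time.perf_counter()
    for _ in range(n):
        list(fn(*args))  # Drain generators so lazy implementations are timed fairly.
    return (time.perf_counter() - start) * 1000 / n

# avg_ms = time_call(scenario.get_future_tracked_objects, 0, 8.0)
# print(f"get_future_tracked_objects: {avg_ms:.1f} ms")
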
47 changes: 26 additions & 21 deletions nuplan/planning/scenario_builder/nuplan_db/nuplan_scenario_utils.py
@@ -19,6 +19,7 @@
from nuplan.database.nuplan_db.nuplan_db_utils import SensorDataSource, get_lidarpc_sensor_data
from nuplan.database.nuplan_db.nuplan_scenario_queries import (
get_future_waypoints_for_agents_from_db,
get_future_waypoints_for_agents_from_db_optimized,
get_sampled_sensor_tokens_in_time_window_from_db,
get_sensor_data_token_timestamp_from_db,
get_tracked_objects_for_lidarpc_token_from_db,
@@ -336,50 +337,54 @@ def extract_tracked_objects(
future_trajectory_sampling: Optional[TrajectorySampling] = None,
) -> TrackedObjects:
"""
Extracts all boxes from a lidarpc.
:param lidar_pc: Input lidarpc.
:param future_trajectory_sampling: If provided, the future trajectory sampling to use for future waypoints.
:return: Tracked objects contained in the lidarpc.
Extracts all boxes from a lidarpc, considering future trajectory sampling if provided.
"""
tracked_objects: List[TrackedObject] = []
agent_indexes: Dict[str, int] = {}
agent_future_trajectories: Dict[str, List[Waypoint]] = {}

    # Fetch all tracked objects associated with the current lidar point cloud.
for idx, tracked_object in enumerate(get_tracked_objects_for_lidarpc_token_from_db(log_file, token)):
if future_trajectory_sampling and isinstance(tracked_object, Agent):
agent_indexes[tracked_object.metadata.track_token] = idx
agent_future_trajectories[tracked_object.metadata.track_token] = []
tracked_objects.append(tracked_object)

    # Guard on agent_indexes: the waypoint query builds an IN (...) clause, which is invalid SQL for an empty token list.
    if future_trajectory_sampling and agent_indexes:
timestamp_time = get_sensor_data_token_timestamp_from_db(log_file, get_lidarpc_sensor_data(), token)
if timestamp_time is None:
return TrackedObjects(tracked_objects=tracked_objects)

        # Fetch the future waypoints with the optimized query, which resamples in the SQL layer.
        future_waypoints = get_future_waypoints_for_agents_from_db_optimized(
            log_file, list(agent_indexes.keys()), timestamp_time, future_trajectory_sampling
        )

        # Group the future waypoints by the track token of each tracked object.
        # (agent_future_trajectories was initialized with an empty list per agent above.)
        for track_token, waypoint in future_waypoints:
            agent_future_trajectories[track_token].append(waypoint)

        # Update each tracked object's predicted trajectory from the fetched future waypoints.
        for track_token, waypoints in agent_future_trajectories.items():
            idx = agent_indexes[track_token]
            if len(waypoints) > 1:
                # We can only interpolate waypoints if there is more than one in the future.
                tracked_objects[idx]._predictions = [
                    PredictedTrajectory(
                        1.0,  # Probability of the single predicted mode, assumed to be 1.0.
                        interpolate_future_waypoints(
                            waypoints,
                            future_trajectory_sampling.time_horizon,
                            future_trajectory_sampling.interval_length,
                        ),
                    )
                ]
            elif len(waypoints) == 1:
                tracked_objects[idx]._predictions = [PredictedTrajectory(1.0, waypoints)]

return TrackedObjects(tracked_objects=tracked_objects)

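For reviewers unfamiliar with the recursive CTE used in get_future_waypoints_for_agents_from_db_optimized, here is the timestamp-generation step in isolation, runnable against any SQLite database:

import sqlite3

conn = sqlite3.connect(":memory:")
query = """
WITH RECURSIVE sampled_timestamps(ts) AS (
    SELECT ?            -- Start timestamp.
    UNION ALL
    SELECT ts + ?       -- Step forward by the sampling interval.
    FROM sampled_timestamps
    WHERE ts + ? <= ?   -- Stop once the end timestamp is reached.
)
SELECT ts FROM sampled_timestamps;
"""
# 0.5 s interval over a 2 s horizon, in microseconds.
print(conn.execute(query, (0, 500_000, 500_000, 2_000_000)).fetchall())
# [(0,), (500000,), (1000000,), (1500000,), (2000000,)]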