CRAT-Pred: Vehicle Trajectory Prediction with Crystal Graph Convolutional Neural Networks and Multi-Head Self-Attention
Official repository of the paper:
CRAT-Pred: Vehicle Trajectory Prediction with Crystal Graph Convolutional Neural Networks and Multi-Head Self-Attention
Julian Schmidt, Julian Jordan, Franz Gritschneder and Klaus Dietmayer
Accepted at 2022 IEEE International Conference on Robotics and Automation (ICRA)
If you use our source code, please cite:
@InProceedings{schmidt2022cratpred,
author={Julian Schmidt and Julian Jordan and Franz Gritschneder and Klaus Dietmayer},
booktitle={2022 IEEE International Conference on Robotics and Automation (ICRA)},
title={CRAT-Pred: Vehicle Trajectory Prediction with Crystal Graph Convolutional Neural Networks and Multi-Head Self-Attention},
year={2022},
pages={7799--7805}
}
CRAT-Pred is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License.
Check LICENSE for more information.
We recommend using Anaconda.
Installation instructions for Linux are available at:
https://docs.anaconda.com/anaconda/install/linux/
Create and activate the conda environment, then install the Argoverse API:
conda env create -f environment.yml
conda activate crat-pred
pip install git+https://github.com/argoai/argoverse-api.git
Fetch the dataset:
bash fetch_dataset.sh
Both online and offline preprocessing are implemented. To train offline on the preprocessed dataset, first run:
python3 preprocess.py
Alternatively, skip this step and let the data be preprocessed online during training.
Train with online preprocessing:
python3 train.py
or, after running preprocess.py, train on the preprocessed dataset:
python3 train.py --use_preprocessed=True
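The difference between the two modes can be illustrated with a minimal, generic sketch. It is not the repository's implementation: the functions `preprocess_offline` and `load_sample`, the per-file pickle cache, and the `preprocess_fn` callback are all hypothetical. Offline mode pays the preprocessing cost once and reads cached results during training; online mode recomputes each sample whenever it is loaded.

```python
import pickle
from pathlib import Path
from typing import Any, Callable, Iterable

# Hypothetical helpers illustrating offline vs. online preprocessing;
# they are not part of the CRAT-Pred code base.


def preprocess_offline(raw_files: Iterable[Path], cache_dir: Path,
                       preprocess_fn: Callable[[Path], Any]) -> None:
    """Preprocess every raw file once and store the result as a pickle."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    for raw_file in raw_files:
        with open(cache_dir / (raw_file.stem + ".pkl"), "wb") as fh:
            pickle.dump(preprocess_fn(raw_file), fh)


def load_sample(raw_file: Path, preprocess_fn: Callable[[Path], Any],
                use_preprocessed: bool, cache_dir: Path) -> Any:
    """Return one training sample.

    Offline: read the cached result written by preprocess_offline.
    Online: run preprocess_fn on the raw file every time it is requested.
    """
    if use_preprocessed:
        with open(cache_dir / (raw_file.stem + ".pkl"), "rb") as fh:
            return pickle.load(fh)
    return preprocess_fn(raw_file)
```

Offline preprocessing trades disk space and an extra setup step for cheaper data loading in every epoch; online preprocessing avoids both at the cost of repeating the same work each time a sample is loaded.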
Checkpoints are saved in the lightning_logs/ folder.
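To pick a checkpoint for the --weight argument of test.py, a small helper can search the log folder. This sketch assumes PyTorch Lightning's usual layout (lightning_logs/version_<n>/checkpoints/*.ckpt) and treats the most recently modified file as the newest checkpoint; `latest_checkpoint` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path
from typing import Optional, Union


def latest_checkpoint(log_dir: Union[str, Path] = "lightning_logs") -> Optional[Path]:
    """Return the most recently modified .ckpt file under log_dir, or None.

    rglob searches all subfolders, so any version_<n>/checkpoints/
    nesting is covered (assumed PyTorch Lightning default layout).
    """
    checkpoints = list(Path(log_dir).rglob("*.ckpt"))
    if not checkpoints:
        return None
    # Newest by modification time; this is not necessarily the best model.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print(latest_checkpoint())
```

The printed path can be passed to test.py via --weight. The newest checkpoint is not necessarily the best one; if checkpoints are kept according to a validation metric, compare the file names or the TensorBoard curves before choosing.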
To view metrics and losses in TensorBoard, first start the server:
tensorboard --logdir lightning_logs/
Then open http://localhost:6006/ in a browser.
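To evaluate a trained model, pass the path of its checkpoint:
python3 test.py --weight=/path/to/checkpoint.ckpt
To evaluate on the test split, add --split=test:
python3 test.py --weight=/path/to/checkpoint.ckpt --split=test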