@@ -0,0 +1,16 @@
__pycache__
*/__pycache__/
*/.ipynb_checkpoints/
*/.idea/
*/.vscode/
*/.pytest_cache/
*/.git/
*/.gitignore
*/.DS_Store
*/.env
*/.env.example
*/.envrc
*/.venv/
logs/
preprocess
scripts
@@ -0,0 +1,145 @@
# PFGS: High Fidelity Point Cloud Rendering via Feature Splatting
<!-- 
 
  -->

> [PFGS: High Fidelity Point Cloud Rendering via Feature Splatting](https://arxiv.org/abs/2407.03857)
> Jiaxu Wang<sup>†</sup>, Ziyi Zhang<sup>†</sup>, Junhao He, Renjing Xu*
> ECCV 2024
>


If you find this project useful, please [cite](#citation) us in your paper; this is the greatest support for us.

### Requirements (Tested on 1 * RTX3090)
- Linux
- Python == 3.8
- PyTorch == 1.13.0
- CUDA == 11.7

## Installation

### Install from environment.yml
You can install the requirements directly with:
```sh
$ conda env create -f environment.yml
```

### Or install packages separately
* Create environment
```sh
$ conda create --name PFGS python=3.8
$ conda activate PFGS
```

* PyTorch (please check your CUDA version first)
```sh
$ conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
```
* Other Python packages: open3d, opencv-python, etc.

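A quick sanity check (not part of the original instructions) to confirm the installed PyTorch/CUDA combination matches the tested configuration:

```python
import torch

print(torch.__version__)          # expected: 1.13.0
print(torch.version.cuda)         # expected: 11.7
print(torch.cuda.is_available())  # should be True on a CUDA-capable machine
```
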
### Gaussian Rasterization with High-dimensional Features
```shell
pip install ./submodules/diff-gaussian-rasterization
```
You can customize `NUM_SEMANTIC_CHANNELS` in `submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h` to any feature dimension you want.

### Install third_party

## Dataset
#### ScanNet:
- Download and extract the data with the original [ScanNet-V2 preprocessing scripts](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python).

- Dataset structure (a minimal loading sketch follows this list):
```
── scannet
    └── scene0000_00
        ├── pose
            └── 1.txt
        ├── intrinsic
            └── *.txt
        ├── color
            └── 1.jpg
        └── scene0000_00_vh_clean_2.ply
        └── images.txt
    └── scene0000_01
```
- [Pretrain](https://1drv.ms/u/c/747194122a3acf02/EQzE6ue3ZglLsUbfVP8uDk8BJa4C_sfILsqd5fjo5L4Dug?e=eslXip)
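
A minimal loading sketch for the layout above (it assumes the pose files hold 4x4 camera-to-world matrices and the intrinsic files are readable by `np.loadtxt`, as in the ScanNet export; paths are placeholders):

```python
import glob
import numpy as np
import open3d as o3d
from PIL import Image

scene = "scannet/scene0000_00"                    # placeholder path
pose = np.loadtxt(f"{scene}/pose/1.txt")          # (4, 4) camera-to-world matrix
intrinsic = np.loadtxt(sorted(glob.glob(f"{scene}/intrinsic/*.txt"))[0])
image = np.array(Image.open(f"{scene}/color/1.jpg"), dtype=np.float32) / 255.0
pcd = o3d.io.read_point_cloud(f"{scene}/scene0000_00_vh_clean_2.ply")
points, colors = np.asarray(pcd.points), np.asarray(pcd.colors)  # (N, 3) each
```
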
#### DTU:
- We reorganize the original datasets in our own format. Here we provide a demonstration of the DTU test set, which can be downloaded [here](https://1drv.ms/u/c/747194122a3acf02/EdwjDcTXBwpAmyKqDEqjsZMBiUoxXpJ2o1QCYdt8WmMGOA?e=nvceS7).
- [Pretrain](https://1drv.ms/u/c/747194122a3acf02/EQzE6ue3ZglLsUbfVP8uDk8BJa4C_sfILsqd5fjo5L4Dug?e=eslXip)
#### THuman2:
- Download the 3D models and extract the data from the original [THuman2](https://github.com/ytrock/THuman2.0-Dataset) dataset.
- Render 36 views of each 3D model and sparsely sample points (80k) on the model surface with Blender (see the sampling sketch after this list).
- [Demo](https://1drv.ms/u/c/747194122a3acf02/EbCeCGAeY7hKgW28xfp3XvUB7snppGkG7dnumzg-eW7lVg?e=fanaHb) and [Pretrain](https://1drv.ms/u/c/747194122a3acf02/EQzE6ue3ZglLsUbfVP8uDk8BJa4C_sfILsqd5fjo5L4Dug?e=eslXip)
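
If you prefer not to set up a Blender pipeline for the point-sampling step, a rough open3d-based alternative could look like the sketch below (the mesh and output paths are placeholders; the 36 view renderings themselves still need Blender or another renderer):

```python
import open3d as o3d

mesh = o3d.io.read_triangle_mesh("thuman2/0001/0001.obj")    # placeholder path
pcd = mesh.sample_points_uniformly(number_of_points=80_000)  # ~80k surface points
o3d.io.write_point_cloud("thuman2/0001/points_80k.ply", pcd)
```
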
## Train Stage 1
#### ScanNet:
```shell
python train_stage1.py --dataset scannet --scene_dir $data_path --exp_name scannet_stage1 --img_wh 640 512
```
#### DTU:
```shell
python train_stage1.py --dataset dtu --scene_dir $data_path --exp_name dtu_stage1 --img_wh 640 512
```
#### THuman2:
```shell
python train_stage1.py --dataset thuman2 --scene_dir $data_path --exp_name thuman2_stage1 --img_wh 512 512 --scale_max 0.0001
```

## Train Stage 2
#### ScanNet:
```shell
python train_stage2.py --dataset scannet --scene_dir $data_path --exp_name scannet_stage2 --img_wh 640 512 --ckpt_stage1 $ckpt_stage1_path
```
#### DTU:
```shell
python train_stage2.py --dataset dtu --scene_dir $data_path --exp_name dtu_stage2 --img_wh 640 512 --ckpt_stage1 $ckpt_stage1_path
```
#### THuman2:
```shell
python train_stage2.py --dataset thuman2 --scene_dir $data_path --exp_name thuman2_stage2 --img_wh 512 512 --scale_max 0.0001 --ckpt_stage1 $ckpt_stage1_path
```

## Eval
#### ScanNet:
```shell
python train_stage2.py --dataset scannet --scene_dir $data_path --exp_name scannet_stage2_eval --img_wh 640 512 --resume_path $ckpt_stage2_path --val_mode test
```
#### DTU:
```shell
python train_stage2.py --dataset dtu --scene_dir $data_path --exp_name dtu_stage2_eval --img_wh 640 512 --resume_path $ckpt_stage2_path --val_mode test
```
#### THuman2:
```shell
python train_stage2.py --dataset thuman2 --scene_dir $data_path --exp_name thuman2_stage2_eval --img_wh 512 512 --scale_max 0.0001 --resume_path $ckpt_stage2_path --val_mode test
```
The results will be saved in `./log/$exp_name`.

## Acknowledgements
This repository uses code and datasets from the following repositories. We thank all the authors for sharing their great work.
- [DTU](https://roboimagedata.compute.dtu.dk/?page_id=36)
- [ScanNet](https://github.com/ScanNet/ScanNet)
- [THuman2](https://github.com/ytrock/THuman2.0-Dataset)
- [Trivol](https://github.com/dvlab-research/TriVol)
- [3D Gaussian Splatting](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/)
- [Feature-3DGS](https://github.com/ShijieZhou-UCLA/feature-3dgs)
- [LION](https://github.com/nv-tlabs/LION)
- [MIMO-UNet](https://github.com/chosj95/MIMO-UNet)

## Citation
```
@misc{wang2024pfgshighfidelitypoint,
      title={PFGS: High Fidelity Point Cloud Rendering via Feature Splatting},
      author={Jiaxu Wang and Ziyi Zhang and Junhao He and Renjing Xu},
      year={2024},
      eprint={2407.03857},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2407.03857},
}
```
@@ -0,0 +1,4 @@
from .scannet import ScanNetDataset
from .dtu import DtuDataset
from .thuman2 import THuman2Dataset
from .common import *
@@ -0,0 +1,90 @@
import os
import glob
import random
from PIL import Image
import numpy as np
import cv2
import open3d as o3d

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from kornia import create_meshgrid
import torchvision
import imageio
import lpips


def get_ray_directions_opencv(W, H, fx, fy, cx, cy):
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        W, H: image width and height
        fx, fy, cx, cy: focal lengths and principal point (OpenCV convention)
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0]
    i, j = grid.unbind(-1)
    # the direction here is without +0.5 pixel centering as calibration is not so accurate
    # see https://github.com/bmild/nerf/issues/24
    directions = \
        torch.stack([(i-cx)/fx, (j-cy)/fy, torch.ones_like(i)], -1)  # (H, W, 3)

    return directions


def get_rays(directions, c2w):
    """
    Get ray origin and normalized directions in world coordinate for all pixels in one image.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        directions: (H, W, 3) precomputed ray directions in camera coordinate
        c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
    Outputs:
        rays_o: (H, W, 3), the origin of the rays in world coordinate
        rays_d: (H, W, 3), the normalized direction of the rays in world coordinate
    """
    # Rotate ray directions from camera coordinate to the world coordinate
    rays_d = directions @ c2w[:3, :3].T  # (H, W, 3)
    rays_d = rays_d / (torch.norm(rays_d, dim=-1, keepdim=True) + 1e-8)
    # The origin of all rays is the camera origin in world coordinate
    rays_o = c2w[:3, 3].expand(rays_d.shape)  # (H, W, 3)

    return rays_o, rays_d


def trivol_collate_fn(list_data):
    cam_extrinsics = torch.stack([d["cam_extrinsics"] for d in list_data])
    cam_intrinsics = torch.stack([d["cam_intrinsics"] for d in list_data])
    rgb_batch = torch.stack([d["rgbs"] for d in list_data])
    pointclouds_batch = torch.stack([d["point_cloud"] for d in list_data])
    H_batch = [d["H"] for d in list_data]
    W_batch = [d["W"] for d in list_data]
    ply_path = [d["ply_path"] for d in list_data]
    paths = [d["paths"] for d in list_data]
    filenames = [d["filename"] for d in list_data]
    znear = torch.stack([torch.tensor(d["znear"]) for d in list_data])
    zfar = torch.stack([torch.tensor(d["zfar"]) for d in list_data])

    return {
        "cam_extrinsics": cam_extrinsics,
        "cam_intrinsics": cam_intrinsics,
        "rgbs": rgb_batch,
        "H": H_batch,
        "W": W_batch,
        "point_cloud": pointclouds_batch,
        "ply_path": ply_path,
        "paths": paths,
        "filename": filenames,
        "znear": znear,
        "zfar": zfar,
    }
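

# Usage sketch (added for illustration; the camera parameters below are made-up values,
# not taken from any dataset): build per-pixel ray directions once for a camera, then
# transform them into world space with a camera-to-world pose.
if __name__ == "__main__":
    directions = get_ray_directions_opencv(W=640, H=512, fx=577.0, fy=577.0, cx=320.0, cy=256.0)
    c2w = torch.eye(4)[:3]                      # identity pose, shape (3, 4)
    rays_o, rays_d = get_rays(directions, c2w)
    print(rays_o.shape, rays_d.shape)           # torch.Size([512, 640, 3]) twice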
@@ -0,0 +1,95 @@
import numpy as np
from PIL import Image
import math


def read_image(filename, max_dim=-1):
    """Read an image and scale intensities from 0~255 to 0~1.
    Args:
        filename: image input file path string
        max_dim: max dimension to scale down the image; keep original size if -1
                 (currently unused, the image is returned at its original size)
    Returns:
        The image as a float32 numpy array in [0, 1]
    """
    image = Image.open(filename)
    # scale 0~255 to 0~1
    np_image = np.array(image, dtype=np.float32) / 255.0
    return np_image


def read_cam_file(filename):
    with open(filename) as f:
        lines = [line.rstrip() for line in f.readlines()]
    # extrinsics: lines [1, 5), 4x4 matrix
    extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
    # intrinsics: lines [7, 10), 3x3 matrix
    intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
    # depth min and max: line 11
    if len(lines) >= 12:
        depth_params = np.fromstring(lines[11], dtype=np.float32, sep=' ')
    else:
        depth_params = np.empty(0)

    return intrinsics, extrinsics, depth_params
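
# Illustration added for clarity (not part of the original file): the slices above imply
# an MVSNet-style cam.txt layout, roughly
#   line 0:     "extrinsic"
#   lines 1-4:  4x4 world-to-camera matrix, one row per line
#   line 6:     "intrinsic"
#   lines 7-9:  3x3 intrinsic matrix, one row per line
#   line 11:    depth parameters (e.g. depth_min depth_interval ...)
# If your files deviate from this layout, adjust the slices in read_cam_file accordingly.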


def project_points(points_3d, colors, K, RT):
    # project Nx3 world points into the image plane: X_cam = RT @ X_h, then x = K @ X_cam
    points_cam = RT @ np.hstack((points_3d, np.ones((len(points_3d), 1)))).T
    points_proj = K @ points_cam
    points_proj = points_proj[:2, :] / points_proj[2, :]
    return points_proj, colors


def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))


def frustrum_clean(pcd, color, intrinsic, extrinsic, W_ori):
    # keep only points whose viewing angle w.r.t. the optical axis lies within half of the
    # horizontal field of view; `intrinsic` is expected as a homogeneous 4x4 matrix here
    center = np.ones((4, 1))
    c = intrinsic[:2, 2][:, np.newaxis]
    center[:2] = c
    pose = np.linalg.inv(extrinsic)
    cam_ori = pose[:3, 3:]

    world_cam_center = (pose @ np.linalg.inv(intrinsic) @ center)[:3]
    view_dir = np.repeat((world_cam_center - cam_ori).transpose((1, 0)), pcd.shape[0], axis=0)
    view_dir = view_dir / np.linalg.norm(view_dir, axis=1, keepdims=True)
    point_dirs = pcd - cam_ori.reshape(-1)
    point_dirs = point_dirs / np.linalg.norm(point_dirs, axis=1, keepdims=True)

    angles = np.arccos(np.sum(point_dirs * view_dir, axis=-1))

    fov = focal2fov(intrinsic[0, 0], W_ori) / 2
    filtered_indices = np.where((angles < fov))
    filtered_points = pcd[filtered_indices]
    color_points = color[filtered_indices]
    return filtered_points, color_points


# flips the Y and Z axes (the usual Blender/OpenGL to OpenCV camera-convention conversion)
T = np.array([[1, 0, 0, 0],
              [0, -1, 0, 0],
              [0, 0, -1, 0],
              [0, 0, 0, 1]])


def prepare_depth(depth):
    # adjust depth maps generated by vision blender
    INVALID_DEPTH = -1
    depth[depth == INVALID_DEPTH] = 0
    return depth


def find_depth(npz_file):
    npz = np.load(npz_file, allow_pickle=True)
    depth = npz['depth_map']
    depth = prepare_depth(depth)
    return depth


def find_pose(npz_file):
    npz = np.load(npz_file, allow_pickle=True)
    poses = npz['object_poses']
    for obj in poses:
        obj_name = obj['name']
        obj_mat = obj['pose']
        if obj_name == 'Camera':
            pose = obj_mat.astype(np.float32)
            break
    return pose @ T
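

# Illustration only (not part of the original file): cull points outside a synthetic
# camera frustum, then project the survivors to pixel coordinates.  All camera
# parameters below are made-up values.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    pts = rng.uniform(-1.0, 1.0, size=(1000, 3)) + np.array([0.0, 0.0, 3.0])  # points in front of the camera
    cols = rng.uniform(0.0, 1.0, size=(1000, 3))
    K = np.array([[500.0, 0.0, 320.0],
                  [0.0, 500.0, 240.0],
                  [0.0, 0.0, 1.0]])
    K4 = np.eye(4)
    K4[:3, :3] = K                  # frustrum_clean expects a homogeneous 4x4 intrinsic
    extrinsic = np.eye(4)           # camera at the origin, looking down +z (OpenCV convention)
    kept_pts, kept_cols = frustrum_clean(pts, cols, K4, extrinsic, W_ori=640)
    uv, kept_cols = project_points(kept_pts, kept_cols, K, extrinsic[:3])
    print(kept_pts.shape, uv.shape)  # (M, 3) and (2, M) for the M surviving points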