first commit
Zerory1 committed Jul 8, 2024
1 parent fafb567 commit b63d472
Showing 1,642 changed files with 280,297 additions and 0 deletions.
16 changes: 16 additions & 0 deletions .gitignore
@@ -0,0 +1,16 @@
__pycache__
*/__pycache__/
*/.ipynb_checkpoints/
*/.idea/
*/.vscode/
*/.pytest_cache/
*/.git/
*/.gitignore
*/.DS_Store
*/.env
*/.env.example
*/.envrc
*/.venv/
logs/
preprocess
scripts
145 changes: 145 additions & 0 deletions README.md
@@ -0,0 +1,145 @@
# PFGS: High Fidelity Point Cloud Rendering via Feature Splatting
<!-- ![issues](https://img.shields.io/github/issues/Mercerai/PFGS)
![forks](https://img.shields.io/github/forks/Mercerai/PFGS?style=flat&color=orange)
![stars](https://img.shields.io/github/stars/Mercerai/PFGS?style=flat&color=red) -->

> [PFGS: High Fidelity Point Cloud Rendering via Feature Splatting](https://arxiv.org/abs/2407.03857)
> Jiaxu Wang<sup>†</sup>, Ziyi Zhang<sup>†</sup>, Junhao He, Renjing Xu*
> ECCV 2024
>
![framework_img](figs/network.png)


If you find this project useful, please [cite](#citation) us in your paper; it is the greatest support for us.

### Requirements (tested on 1 × RTX 3090)
- Linux
- Python == 3.8
- PyTorch == 1.13.0
- CUDA == 11.7

## Installation

### Install from environment.yml
You can install the requirements directly with:
```sh
$ conda env create -f environment.yml
```

### Or install packages separately
* Create Environment
```sh
$ conda create --name PFGS python=3.8
$ conda activate PFGS
```

* PyTorch (please check your CUDA version first)
```sh
$ conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
```
* Other Python packages: open3d, opencv-python, etc.

### Gaussian Rasterization with High-dimensional Features
```shell
pip install ./submodules/diff-gaussian-rasterization
```
You can set `NUM_SEMANTIC_CHANNELS` in `submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h` to any number of feature dimensions you want.

### Install third_party

## Dataset
#### ScanNet:
- Download and extract the data with the original [ScanNet-V2 preprocessing scripts](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python).

- Dataset structure (a minimal camera-file loading sketch follows this list):
```
scannet
├── scene0000_00
│   ├── pose
│   │   └── 1.txt
│   ├── intrinsic
│   │   └── *.txt
│   ├── color
│   │   └── 1.jpg
│   ├── scene0000_00_vh_clean_2.ply
│   └── images.txt
└── scene0000_01
```
- [Pretrained model](https://1drv.ms/u/c/747194122a3acf02/EQzE6ue3ZglLsUbfVP8uDk8BJa4C_sfILsqd5fjo5L4Dug?e=eslXip)
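A minimal loading sketch for the per-frame camera files (not part of the repository): the pose files are assumed to be plain-text 4x4 camera-to-world matrices, and the intrinsic file name below is illustrative.
```python
import numpy as np

# One ScanNet frame: 4x4 camera-to-world pose and 3x3 color intrinsics.
pose = np.loadtxt("scannet/scene0000_00/pose/1.txt")                          # (4, 4)
K = np.loadtxt("scannet/scene0000_00/intrinsic/intrinsic_color.txt")[:3, :3]  # (3, 3)
```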
#### DTU:
- We reorganize the original dataset into our own format. A demo of the DTU test set can be downloaded [here](https://1drv.ms/u/c/747194122a3acf02/EdwjDcTXBwpAmyKqDEqjsZMBiUoxXpJ2o1QCYdt8WmMGOA?e=nvceS7).
- [Pretrained model](https://1drv.ms/u/c/747194122a3acf02/EQzE6ue3ZglLsUbfVP8uDk8BJa4C_sfILsqd5fjo5L4Dug?e=eslXip)
#### THuman2:
- Download the 3D models and extract the data from the original [THuman2](https://github.com/ytrock/THuman2.0-Dataset) repository.
- Render 36 views of each 3D model with Blender and sparsely sample points (80k) on the model surface (see the sketch after this list).
- [Demo](https://1drv.ms/u/c/747194122a3acf02/EbCeCGAeY7hKgW28xfp3XvUB7snppGkG7dnumzg-eW7lVg?e=fanaHb) and [Pretrained model](https://1drv.ms/u/c/747194122a3acf02/EQzE6ue3ZglLsUbfVP8uDk8BJa4C_sfILsqd5fjo5L4Dug?e=eslXip)
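A minimal sketch of the point-sampling step, using open3d's uniform surface sampling as a stand-in for the Blender pipeline; the file paths below are hypothetical.
```python
import open3d as o3d

# Load one THuman2 scan and sparsely sample ~80k points on its surface.
mesh = o3d.io.read_triangle_mesh("thuman2/0001/0001.obj")
pcd = mesh.sample_points_uniformly(number_of_points=80_000)
o3d.io.write_point_cloud("thuman2/0001/points_80k.ply", pcd)
```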
## Train Stage 1
#### ScanNet:
```shell
python train_stage1.py --dataset scannet --scene_dir $data_path --exp_name scannet_stage1 --img_wh 640 512
```
#### DTU:
```shell
python train_stage1.py --dataset dtu --scene_dir $data_path --exp_name dtu_stage1 --img_wh 640 512
```
#### THuman2:
```shell
python train_stage1.py --dataset thuman2 --scene_dir $data_path --exp_name thuman2_stage1 --img_wh 512 512 --scale_max 0.0001
```

## Train Stage 2
#### ScanNet:
```shell
python train_stage2.py --dataset scannet --scene_dir $data_path --exp_name scannet_stage2 --img_wh 640 512 --ckpt_stage1 $ckpt_stage1_path
```
#### DTU:
```shell
python train_stage2.py --dataset dtu --scene_dir $data_path --exp_name dtu_stage2 --img_wh 640 512 --ckpt_stage1 $ckpt_stage1_path
```
#### THuman2:
```shell
python train_stage2.py --dataset thuman2 --scene_dir $data_path --exp_name thuman2_stage2 --img_wh 512 512 --scale_max 0.0001 --ckpt_stage1 $ckpt_stage1_path
```

## Eval
#### ScanNet:
```shell
python train_stage2.py --dataset scannet --scene_dir $data_path --exp_name scannet_stage2_eval --img_wh 640 512 --resume_path $ckpt_stage2_path --val_mode test
```
#### DTU:
```shell
python train_stage2.py --dataset dtu --scene_dir $data_path --exp_name dtu_stage2_eval --img_wh 640 512 --resume_path $ckpt_stage2_path --val_mode test
```
#### THuman2:
```shell
python train_stage2.py --dataset thuman2 --scene_dir $data_path --exp_name thuman2_stage2_eval --img_wh 512 512 --scale_max 0.0001 --resume_path $ckpt_stage2_path --val_mode test
```
The results will be saved in `./log/$exp_name`.

## Acknowledgements
In this repository, we use code and datasets from the following repositories.
We thank all the authors for sharing their great work.
- [DTU](https://roboimagedata.compute.dtu.dk/?page_id=36)
- [ScanNet](https://github.com/ScanNet/ScanNet)
- [THuman2](https://github.com/ytrock/THuman2.0-Dataset)
- [Trivol](https://github.com/dvlab-research/TriVol)
- [3D Gaussian Splatting](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/)
- [Feature-3DGS](https://github.com/ShijieZhou-UCLA/feature-3dgs)
- [LION](https://github.com/nv-tlabs/LION)
- [MIMO-UNet](https://github.com/chosj95/MIMO-UNet)

## Citation
```
@misc{wang2024pfgshighfidelitypoint,
title={PFGS: High Fidelity Point Cloud Rendering via Feature Splatting},
author={Jiaxu Wang and Ziyi Zhang and Junhao He and Renjing Xu},
year={2024},
eprint={2407.03857},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2407.03857},
}
```
4 changes: 4 additions & 0 deletions datasets/__init__.py
@@ -0,0 +1,4 @@
from .scannet import ScanNetDataset
from .dtu import DtuDataset
from .thuman2 import THuman2Dataset
from .common import *
90 changes: 90 additions & 0 deletions datasets/common.py
@@ -0,0 +1,90 @@
import os
import glob
import random
from PIL import Image
import numpy as np
import cv2
import open3d as o3d

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from kornia import create_meshgrid
import torchvision
import imageio
import lpips



def get_ray_directions_opencv(W, H, fx, fy, cx, cy):
"""
Get ray directions for all pixels in camera coordinate.
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
ray-tracing-generating-camera-rays/standard-coordinate-systems
Inputs:
        W, H: image width and height; fx, fy, cx, cy: focal lengths and principal point of the pinhole camera
Outputs:
directions: (H, W, 3), the direction of the rays in camera coordinate
"""
grid = create_meshgrid(H, W, normalized_coordinates=False)[0]
i, j = grid.unbind(-1)
# the direction here is without +0.5 pixel centering as calibration is not so accurate
# see https://github.com/bmild/nerf/issues/24
directions = \
torch.stack([(i-cx)/fx, (j-cy)/fy, torch.ones_like(i)], -1) # (H, W, 3)

return directions


def get_rays(directions, c2w):
"""
Get ray origin and normalized directions in world coordinate for all pixels in one image.
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
ray-tracing-generating-camera-rays/standard-coordinate-systems
Inputs:
directions: (H, W, 3) precomputed ray directions in camera coordinate
c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
Outputs:
        rays_o: (H, W, 3), the origin of the rays in world coordinate
        rays_d: (H, W, 3), the normalized direction of the rays in world coordinate
"""
# Rotate ray directions from camera coordinate to the world coordinate
rays_d = directions @ c2w[:3, :3].T # (H, W, 3)
rays_d = rays_d / (torch.norm(rays_d, dim=-1, keepdim=True) + 1e-8)
# The origin of all rays is the camera origin in world coordinate
rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3)

return rays_o, rays_d


def trivol_collate_fn(list_data):
cam_extrinsics = torch.stack([d["cam_extrinsics"] for d in list_data])
cam_intrinsics = torch.stack([d["cam_intrinsics"] for d in list_data])
rgb_batch = torch.stack([d["rgbs"] for d in list_data])
pointclouds_batch = torch.stack([d["point_cloud"] for d in list_data])
H_batch = [d["H"] for d in list_data]
W_batch = [d["W"] for d in list_data]
ply_path = [d["ply_path"] for d in list_data]
paths = [d["paths"] for d in list_data]
filenames = [d["filename"] for d in list_data]
znear = torch.stack([torch.tensor(d["znear"]) for d in list_data])
zfar = torch.stack([torch.tensor(d["zfar"]) for d in list_data])

return {
"cam_extrinsics": cam_extrinsics,
"cam_intrinsics": cam_intrinsics,
"rgbs": rgb_batch,
"H": H_batch,
"W": W_batch,
"point_cloud": pointclouds_batch,
"ply_path": ply_path,
"paths":paths,
"filename": filenames,
"znear":znear,
"zfar": zfar,
}
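A minimal usage sketch for the ray helpers and the collate function above (not part of the committed file); the intrinsics, image size, and camera pose below are placeholders.
```python
import torch
from torch.utils.data import DataLoader
from datasets.common import get_ray_directions_opencv, get_rays, trivol_collate_fn

# Per-pixel ray directions in camera coordinates for a 640x512 image,
# then world-space ray origins/directions from a camera-to-world matrix.
directions = get_ray_directions_opencv(W=640, H=512, fx=577.8, fy=577.8, cx=320.0, cy=256.0)
c2w = torch.eye(4)[:3]                      # placeholder (3, 4) camera-to-world matrix
rays_o, rays_d = get_rays(directions, c2w)  # each of shape (H, W, 3)

# The collate function stacks the per-sample dicts returned by the dataset
# classes into batched tensors and lists, e.g.:
# loader = DataLoader(dataset, batch_size=2, collate_fn=trivol_collate_fn)
```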
95 changes: 95 additions & 0 deletions datasets/data_utils.py
@@ -0,0 +1,95 @@
import numpy as np
from PIL import Image
import math

def read_image(filename, max_dim=-1):
    """Read an image and scale its pixel values from [0, 255] to [0, 1].
    Args:
        filename: image input file path string
        max_dim: maximum dimension to scale the image down to; currently unused,
            the image is returned at its original size
    Returns:
        float32 numpy array with values in [0, 1]
    """
image = Image.open(filename)
# scale 0~255 to 0~1
np_image = np.array(image, dtype=np.float32) / 255.0
return np_image

def read_cam_file(filename):
with open(filename) as f:
lines = [line.rstrip() for line in f.readlines()]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
# depth min and max: line 11
if len(lines) >= 12:
depth_params = np.fromstring(lines[11], dtype=np.float32, sep=' ')
else:
depth_params = np.empty(0)

return intrinsics, extrinsics, depth_params

def project_points(points_3d, colors, K, RT):
points_cam = RT @ np.hstack((points_3d, np.ones((len(points_3d), 1)))).T
points_proj = K @ points_cam
points_proj = points_proj[:2, :] / points_proj[2, :]
return points_proj, colors

def focal2fov(focal, pixels):
return 2 * math.atan(pixels / (2 * focal))

def frustrum_clean(pcd, color, intrinsic, extrinsic, W_ori):
center = np.ones((4, 1))
c = intrinsic[:2, 2][:,np.newaxis]
center[:2] = c
pose = np.linalg.inv(extrinsic)
cam_ori = pose[:3, 3:]

world_cam_center = (pose @ np.linalg.inv(intrinsic) @ center)[:3]
view_dir = np.repeat((world_cam_center - cam_ori).transpose((1,0)), pcd.shape[0], axis=0)
view_dir = view_dir / np.linalg.norm(view_dir, axis=1, keepdims=True)
point_dirs = pcd - cam_ori.reshape(-1)
point_dirs = point_dirs / np.linalg.norm(point_dirs, axis=1, keepdims=True)

angles = np.arccos(np.sum(point_dirs * view_dir, axis=-1))

fov = focal2fov(intrinsic[0, 0], W_ori) / 2
filtered_indices = np.where((angles < fov))
filtered_points = pcd[filtered_indices]
color_points = color[filtered_indices]
return filtered_points, color_points


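# Axis flip on y and z, commonly used to convert Blender/OpenGL camera poses
# to the OpenCV convention; applied to the camera pose in find_pose below.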
T = np.array([[1,0,0,0],
[0,-1,0,0],
[0,0,-1,0],
[0,0,0,1]])

def prepare_depth(depth):
# adjust depth maps generated by vision blender
INVALID_DEPTH = -1
depth[depth == INVALID_DEPTH] = 0
return depth

def find_depth(npz_file):
npz = np.load(npz_file, allow_pickle=True)
depth = npz['depth_map']
depth = prepare_depth(depth)
return depth

def find_pose(npz_file):
npz = np.load(npz_file, allow_pickle=True)
poses = npz['object_poses']
for obj in poses:
obj_name = obj['name']
obj_mat = obj['pose']
if obj_name == 'Camera':
pose = obj_mat.astype(np.float32)
break
return pose @ T



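A minimal usage sketch for the camera-file parser and projection helper above (not part of the committed file); the camera file path and the point cloud below are placeholders.
```python
import numpy as np
from datasets.data_utils import read_cam_file, project_points

# Parse a camera file with the text layout read_cam_file expects
# (extrinsics on lines 2-5, intrinsics on lines 8-10).
intrinsics, extrinsics, depth_params = read_cam_file("cams/00000000_cam.txt")

# project_points expects a (3, 4) [R|t] matrix, so take the top three rows
# of the 4x4 extrinsics; uv has shape (2, N).
points_3d = np.random.rand(1000, 3).astype(np.float32)
colors = np.ones((1000, 3), dtype=np.float32)
uv, colors = project_points(points_3d, colors, intrinsics, extrinsics[:3])
```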