
Commit

Add training.
dunbar12138 committed Apr 20, 2023
1 parent 637ab04 commit c6995a4
Showing 18 changed files with 3,218 additions and 280 deletions.
41 changes: 39 additions & 2 deletions README.md
@@ -112,8 +112,45 @@ You can also launch an interactive demo of 3D editing.

### Training

Code is coming soon.

We provide an example training script at `train_scripts/afhq_seg.sh`:
```bash
python train.py --outdir=<log_dir> \
--cfg=afhq --data=data/afhq_v2_train_cat_512.zip \
--mask_data=data/afhqcat_seg_6c.zip \
--data_type=seg --semantic_channels=6 \
--render_mask=True --dis_mask=True \
--neural_rendering_resolution_initial=128 \
--resume=<EG3D-checkpoints>/afhqcats512-128.pkl \
--gpus=2 --batch=4 --mbstd-group=2 \
--gamma=5 --gen_pose_cond=True \
--random_c_prob=0.5 \
--lambda_d_semantic=0.1 \
--lambda_lpips=1 \
--lambda_cross_view=1e-4 \
--only_raw_recons=True \
--wandb_log=False
```
Training parameters:
- `outdir`: The directory to save checkpoints and logs.
- `cfg`: Choose from [afhq, celeba, shapenet].
- `data`: RGB data file.
- `mask_data`: Label map data file.
- `data_type`: Choose from [seg, edge]. Specify `semantic_channels` if using `seg` (see the sketch after this list).
- `render_mask`: Whether to render label maps along with RGB.
- `dis_mask`: Whether to use a GAN loss on rendered label maps.
- `neural_rendering_resolution_initial`: The initial resolution of NeRF outputs.
- `resume`: We partially initialize our network with EG3D pretrained checkpoints (download [here](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/eg3d)).
- `gpus`, `batch`, `mbstd-group`: Parameters for batch size and multi-GPU training.
- `gen_pose_cond`: Whether to condition the generation on camera poses.
- `random_c_prob`: Probability of sampling random poses during training.
- `lambda_d_semantic`: The weight of GAN loss on label maps.
- `lambda_lpips`: The weight of RGB LPIPS loss.
- `lambda_cross_view`: The weight of cross-view consistency loss.
- `wandb_log`: Whether to log with Weights & Biases (wandb).
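
As a point of reference, here is a minimal sketch (not from the repository; the file name is hypothetical) of what `data_type=seg` with `semantic_channels=6` implies for a label map: per-pixel class ids expanded to a 6-channel one-hot tensor.

```python
import numpy as np
from PIL import Image

# Hypothetical input: a segmentation map where each pixel holds a class id
# in [0, 5], matching --semantic_channels=6 in the training command above.
label = np.array(Image.open("cat_seg.png"))        # (H, W) integer class ids
num_channels = 6                                   # --semantic_channels

# Expand class ids to a one-hot (C, H, W) tensor, the layout a
# seg-conditioned generator consumes.
one_hot = np.eye(num_channels, dtype=np.float32)[label]   # (H, W, C)
one_hot = one_hot.transpose(2, 0, 1)                       # (C, H, W)
print(one_hot.shape)
```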

### Prepare your own dataset

We follow the dataset format of EG3D, described [here](https://github.com/NVlabs/eg3d#preparing-datasets). You can obtain segmentation masks for your own dataset with [DINO clustering](https://github.com/ShirAmir/dino-vit-features/blob/main/part_cosegmentation.py), and edge maps with [pidinet](https://github.com/hellozhuo/pidinet) and [informative drawing](https://github.com/carolineec/informative-drawings).
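
For reference, EG3D-style datasets pair each image with a 25-number camera label: a flattened 4x4 cam2world matrix (16 values) followed by a flattened 3x3 intrinsics matrix (9 values), matching the `torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)` call in `applications/generate_video.py`. Below is a hedged sketch of writing `dataset.json`; the image path and intrinsic values are illustrative only.

```python
import json
import numpy as np

# Hypothetical camera for one image: 4x4 cam2world extrinsics plus 3x3
# intrinsics, flattened into 16 + 9 = 25 numbers as EG3D expects.
cam2world = np.eye(4, dtype=np.float32)               # replace with a real pose
intrinsics = np.array([[4.26, 0.0, 0.5],
                       [0.0, 4.26, 0.5],
                       [0.0, 0.0, 1.0]], dtype=np.float32)  # illustrative values

labels = [["00000/img00000000.png",
           cam2world.flatten().tolist() + intrinsics.flatten().tolist()]]

with open("dataset.json", "w") as f:
    json.dump({"labels": labels}, f)
```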

---

234 changes: 234 additions & 0 deletions applications/edge2cat.ipynb

Large diffs are not rendered by default.

52 changes: 40 additions & 12 deletions applications/generate_video.py
@@ -61,12 +61,14 @@ def render_video(G, ws, intrinsics, num_frames = 120, pitch_range = 0.25, yaw_ra
        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        with torch.no_grad():
            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
        frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
        # frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
        image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
        frames.append(image_color)
        frames_label.append(color_mask(torch.argmax(out['semantic'], dim=1).cpu().numpy()[0]).astype(np.uint8))

    return frames, frames_label

def render_video_edge2car(G, ws, intrinsics, num_frames = 120, pitch_range = np.pi / 2, yaw_range = np.pi, neural_rendering_resolution = 64, device='cuda'):
def render_video_edge(G, ws, intrinsics, num_frames = 120, pitch_range = np.pi / 2, yaw_range = np.pi, neural_rendering_resolution = 64, device='cuda'):
    frames, frames_label = [], []

    for frame_idx in tqdm(range(num_frames)):
@@ -81,6 +83,21 @@ def render_video_edge2car(G, ws, intrinsics, num_frames = 120, pitch_range = np.

    return frames, frames_label

def render_video_edge2cat(G, ws, intrinsics, num_frames = 120, pitch_range = np.pi / 2, yaw_range = np.pi, neural_rendering_resolution = 64, device='cuda'):
    frames, frames_label = [], []

    for frame_idx in tqdm(range(num_frames)):
        cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
                                                  3.14/2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
                                                  torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device), radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        with torch.no_grad():
            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
        frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
        frames_label.append(((out['semantic'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8)[0])

    return frames, frames_label

def main():
    # Parse arguments
    parser = argparse.ArgumentParser(description='Generate samples from a trained model')
@@ -100,7 +117,7 @@ def main():
    with dnnlib.util.open_url(args.network) as f:
        G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)

    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
    if args.cfg == 'seg2cat' or args.cfg == 'seg2face' or args.cfg == 'edge2cat':
        neural_rendering_resolution = 128
        pitch_range, yaw_range = 0.25, 0.35
        data_type = 'seg'
@@ -135,9 +152,15 @@ def main():

            # Save the visualized input label map
            PIL.Image.fromarray(color_mask(input_label[0,0].cpu().numpy()).astype(np.uint8)).save(save_dir / f'{args.cfg}_input.png')
        elif args.cfg == 'edge2car':
            input_label = np.array(input_label).astype(np.float32)[..., 0]
            input_label = -(torch.tensor(input_label).to(torch.float32) / 127.5 - 1).unsqueeze(0).unsqueeze(0).to(device)
        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
            input_label = np.array(input_label).astype(np.float32)
            if input_label.ndim == 3:
                input_label = input_label[:,:,0]
            print(input_label.min(), input_label.max())
            input_label = (torch.tensor(input_label).to(torch.float32) / 127.5 - 1).unsqueeze(0).unsqueeze(0).to(device)
            plt.imshow(input_label.cpu().numpy()[0,0], cmap='gray')
            plt.savefig(save_dir / f'{args.cfg}_input.png')

        input_pose = forward_pose.to(device)

    elif args.input_id is not None:
@@ -148,8 +171,13 @@ def main():
            data_path = Path(args.data_dir) / 'cars_128.zip'
            mask_data = Path(args.data_dir) / 'shapenet_car_contour.zip'
        elif args.cfg == 'seg2face':
            data_path = Path(args.data_dir) / 'celebamask_test.zip'
            mask_data = Path(args.data_dir) / 'celebamask_test_label.zip'
            # data_path = Path(args.data_dir) / 'celebamask_test.zip'
            # mask_data = Path(args.data_dir) / 'celebamask_test_label.zip'
            data_path = '/data2/datasets/CelebAMask_eg3d/test/celebamask_test.zip'
            mask_data = '/data2/datasets/CelebAMask_eg3d/test/celebamask_test_label.zip'
        elif args.cfg == 'edge2cat':
            data_path = '/data2/datasets/AFHQ_eg3d/afhq_v2_train_cat_512.zip'
            mask_data = '/data2/datasets/AFHQ_eg3d/afhqcat_contour_pidinet.zip'

        dataset_kwargs, dataset_name = init_conditional_dataset_kwargs(str(data_path), str(mask_data), data_type)
        dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
@@ -160,13 +188,13 @@
        # Save the input label map
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            PIL.Image.fromarray(color_mask(batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')
        elif args.cfg == 'edge2car':
        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
            PIL.Image.fromarray((255 - batch['mask'][0]).astype(np.uint8)).save(save_dir / f'{args.cfg}_{args.input_id}_input.png')

        input_pose = torch.tensor(batch['pose']).unsqueeze(0).to(device)
        if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
            input_label = torch.tensor(batch['mask']).unsqueeze(0).to(device)
        elif args.cfg == 'edge2car':
        elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
            input_label = -(torch.tensor(batch['mask']).to(torch.float32) / 127.5 - 1).unsqueeze(0).to(device)

    # Generate videos
@@ -179,8 +207,8 @@
    # Generate the video
    if args.cfg == 'seg2cat' or args.cfg == 'seg2face':
        frames, frames_label = render_video(G, ws, intrinsics, num_frames = 120, pitch_range = pitch_range, yaw_range = yaw_range, neural_rendering_resolution=neural_rendering_resolution, device=device)
    elif args.cfg == 'edge2car':
        frames, frames_label = render_video_edge2car(G, ws, intrinsics, num_frames = 120, pitch_range = pitch_range, yaw_range = yaw_range, neural_rendering_resolution=neural_rendering_resolution, device=device)
    elif args.cfg == 'edge2car' or args.cfg == 'edge2cat':
        frames, frames_label = render_video_edge2cat(G, ws, intrinsics, num_frames = 120, pitch_range = pitch_range, yaw_range = yaw_range, neural_rendering_resolution=neural_rendering_resolution, device=device)

    # Save the video
    imageio.mimsave(save_dir / f'{args.cfg}_{seed}.gif', frames, fps=60)
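One detail worth noting in the diff above: edge maps are fed to the generator in a normalized [-1, 1] range, and the `edge2car` path negates the map while the `edge2cat` path does not. Below is a minimal sketch of that convention, assuming a grayscale edge image (the file name is hypothetical).

```python
import numpy as np
import torch
from PIL import Image

# Hypothetical edge map: grayscale uint8 in [0, 255].
edge = np.array(Image.open("sketch.png").convert("L")).astype(np.float32)

# Normalize to [-1, 1] exactly as generate_video.py does: x / 127.5 - 1.
edge_norm = torch.from_numpy(edge) / 127.5 - 1.0

# edge2cat feeds the normalized map directly, while edge2car negates it,
# flipping which of stroke/background maps to +1.
label_edge2cat = edge_norm.unsqueeze(0).unsqueeze(0)    # (1, 1, H, W)
label_edge2car = -label_edge2cat
```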
Binary file added examples/example_input_edge2cat.png
11 changes: 11 additions & 0 deletions metrics/__init__.py
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# empty