Skip to content

Commit

Permalink
Fixes to #6
Browse files Browse the repository at this point in the history
  • Loading branch information
fyviezhao committed Nov 24, 2021
1 parent 0b0fb7a commit 3c9d16c
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 46 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
results/
checkpoints/
*.pth
.DS_Store
test_opt.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,13 @@ The use of this code and the MPV3D dataset is RESTRICTED to non-commercial resea
## Citation
If our code is helpful to your research, please cite:
```
@article{Zhao2021M3DVTONAM,
title={M3D-VTON: A Monocular-to-3D Virtual Try-On Network},
author={Fuwei Zhao and Zhenyu Xie and Michael C. Kampffmeyer and Haoye Dong and Songfang Han and Tianxiang Zheng and Tao Zhang and Xiaodan Liang},
journal={ArXiv},
year={2021},
volume={abs/2108.05126}
@InProceedings{M3D-VTON,
author = {Zhao, Fuwei and Xie, Zhenyu and Kampffmeyer, Michael and Dong, Haoye and Han, Songfang and Zheng, Tianxiang and Zhang, Tao and Liang, Xiaodan},
title = {M3D-VTON: A Monocular-to-3D Virtual Try-On Network},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2021},
pages = {13239-13249}
}
```

33 changes: 20 additions & 13 deletions data/aligned_MPV3dDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, opt):
"""
# save the option and dataset root
BaseDataset.__init__(self, opt)
self.isTrain = opt.isTrain
self.model = opt.model
self.img_width, self.img_height = opt.img_width, opt.img_height
self.radius = opt.radius
Expand Down Expand Up @@ -174,29 +175,35 @@ def __getitem__(self, index):
imhal_sobelx, imhal_sobely = '', ''

# im depth (front)
imfd = np.load(os.path.join(self.dataroot, 'depth', im_name.replace('.png', '_depth.npy')))
imfd_m = (imfd > 0).astype(np.float32)
imfd = -1 * (2 * imfd -1) # viewport -> ndc -> world
imfd = imfd * imfd_m
imfd = torch.from_numpy(imfd).unsqueeze(0)
if self.model == 'DRM' or self.model == 'TFM':
if self.model == 'MTM' and self.isTrain:
imfd = np.load(os.path.join(self.dataroot, 'depth', im_name.replace('.png', '_depth.npy')))
imfd_m = (imfd > 0).astype(np.float32)
imfd = -1 * (2 * imfd -1) # viewport -> ndc -> world
imfd = imfd * imfd_m
imfd = torch.from_numpy(imfd).unsqueeze(0)
elif self.model == 'DRM' or self.model == 'TFM':
imfd = ''
imfd_initial = np.load(os.path.join(self.warproot, 'initial-depth', im_name.replace('whole_front.png', 'initial_front_depth.npy')))
imfd_initial = torch.from_numpy(imfd_initial).unsqueeze(0)
else:
imfd = ''
imfd_initial = ''


# im depth (back)
imbd = np.load(os.path.join(self.dataroot, 'depth', im_name.replace('front.png', 'back_depth.npy')))
imbd = np.flip(imbd, axis = 1) # align with imfd
imbd_m = (imbd > 0).astype(np.float32)
imbd = 2 * imbd -1 # viewport -> ndc -> world
imbd = imbd * imbd_m
imbd = torch.from_numpy(imbd).unsqueeze(0)
if self.model == 'DRM':
if self.model == 'MTM' and self.isTrain:
imbd = np.load(os.path.join(self.dataroot, 'depth', im_name.replace('front.png', 'back_depth.npy')))
imbd = np.flip(imbd, axis = 1) # align with imfd
imbd_m = (imbd > 0).astype(np.float32)
imbd = 2 * imbd -1 # viewport -> ndc -> world
imbd = imbd * imbd_m
imbd = torch.from_numpy(imbd).unsqueeze(0)
elif self.model == 'DRM':
imbd = ''
imbd_initial = np.load(os.path.join(self.warproot, 'initial-depth', im_name.replace('whole_front.png', 'initial_back_depth.npy')))
imbd_initial = torch.from_numpy(imbd_initial).unsqueeze(0)
else:
imbd = ''
imbd_initial = ''

# load pose points
Expand Down
29 changes: 18 additions & 11 deletions data/unaligned_MPV3dDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, opt):
"""
# save the option and dataset root
BaseDataset.__init__(self, opt)
self.isTrain = opt.isTrain
self.model = opt.model
self.img_width, self.img_height = opt.img_width, opt.img_height
self.radius = opt.radius
Expand Down Expand Up @@ -175,29 +176,35 @@ def __getitem__(self, index):
imhal_sobelx, imhal_sobely = '', ''

# im depth (front)
imfd = np.load(os.path.join(self.dataroot, 'depth', im_name.replace('.png', '_depth.npy')))
imfd_m = (imfd > 0).astype(np.float32)
imfd = -1 * (2 * imfd -1) # viewport -> ndc -> world
imfd = imfd * imfd_m
imfd = torch.from_numpy(imfd).unsqueeze(0)
if self.model == 'MTM' and self.isTrain:
imfd = np.load(os.path.join(self.dataroot, 'depth', im_name.replace('.png', '_depth.npy')))
imfd_m = (imfd > 0).astype(np.float32)
imfd = -1 * (2 * imfd -1) # viewport -> ndc -> world
imfd = imfd * imfd_m
imfd = torch.from_numpy(imfd).unsqueeze(0)
if self.model == 'DRM' or self.model == 'TFM':
imfd = ''
imfd_initial = np.load(os.path.join(self.warproot, 'initial-depth', im_name.replace('whole_front.png', 'initial_front_depth.npy')))
imfd_initial = torch.from_numpy(imfd_initial).unsqueeze(0)
else:
imfd = ''
imfd_initial = ''


# im depth (back)
imbd = np.load(os.path.join(self.dataroot, 'depth', im_name.replace('front.png', 'back_depth.npy')))
imbd = np.flip(imbd, axis = 1) # align with imfd
imbd_m = (imbd > 0).astype(np.float32)
imbd = 2 * imbd -1 # viewport -> ndc -> world
imbd = imbd * imbd_m
imbd = torch.from_numpy(imbd).unsqueeze(0)
if self.model == 'MTM' and self.isTrain:
imbd = np.load(os.path.join(self.dataroot, 'depth', im_name.replace('front.png', 'back_depth.npy')))
imbd = np.flip(imbd, axis = 1) # align with imfd
imbd_m = (imbd > 0).astype(np.float32)
imbd = 2 * imbd -1 # viewport -> ndc -> world
imbd = imbd * imbd_m
imbd = torch.from_numpy(imbd).unsqueeze(0)
if self.model == 'DRM':
imbd = ''
imbd_initial = np.load(os.path.join(self.warproot, 'initial-depth', im_name.replace('whole_front.png', 'initial_back_depth.npy')))
imbd_initial = torch.from_numpy(imbd_initial).unsqueeze(0)
else:
imbd = ''
imbd_initial = ''

# load pose points
Expand Down
19 changes: 10 additions & 9 deletions models/DRM_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,17 @@ def set_input(self, input):
self.imhal_sobely = input['imhal_sobely'].to(self.device) # for input
self.c_sobelx = input['cloth_sobelx'].to(self.device) # for input
self.c_sobely = input['cloth_sobely'].to(self.device) # for input
self.imfd = input['person_fdepth'].to(self.device) # for ground truth
self.imbd = input['person_bdepth'].to(self.device) # for ground truth
if self.use_grad_loss:
self.fgrad = self.compute_grad(self.imfd) # for ground truth
self.bgrad = self.compute_grad(self.imbd) # for ground truth
if self.isTrain:
self.imfd = input['person_fdepth'].to(self.device) # for ground truth
self.imbd = input['person_bdepth'].to(self.device) # for ground truth
if self.use_grad_loss:
self.fgrad = self.compute_grad(self.imfd) # for ground truth
self.bgrad = self.compute_grad(self.imbd) # for ground truth

if self.use_normal_loss or self.use_gan_loss:
self.im_mask = input['person_mask'].to(self.device) # for input
self.imfn = util.depth2normal_ortho(self.imfd).to(self.device) # for ground truth
self.imbn = util.depth2normal_ortho(self.imbd).to(self.device) # for ground truth
if self.use_normal_loss or self.use_gan_loss:
self.im_mask = input['person_mask'].to(self.device) # for input
self.imfn = util.depth2normal_ortho(self.imfd).to(self.device) # for ground truth
self.imbn = util.depth2normal_ortho(self.imbd).to(self.device) # for ground truth

def forward(self):
"""Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
Expand Down
16 changes: 10 additions & 6 deletions models/MTM_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,19 @@ def set_input(self, input):
self.c_name = input['c_name'] # meta info
self.agnostic = input['agnostic'].to(self.device) # for input
self.c = input['cloth'].to(self.device) # for input
self.im_c = input['parse_cloth'].to(self.device) # for ground truth
self.fdepth_gt = input['person_fdepth'].to(self.device) # for ground truth
self.bdepth_gt = input['person_bdepth'].to(self.device) # for ground truth
self.segmt_gt = input['person_parse'].long().to(self.device) # for ground truth
self.cm = input['cloth_mask'].to(self.device) # for visual
self.im = input['person'].to(self.device) # for visual
self.im_shape = input['person_shape'] # for visual
self.im_hhl = input['head_hand_lower'] # for visual
self.pose = input['pose'] # for visual
self.im_g = input['grid_image'].to(self.device) # for visual
self.im_c = input['parse_cloth'].to(self.device) # for ground truth
self.segmt_gt = input['person_parse'].long().to(self.device) # for ground truth
if self.isTrain:
self.fdepth_gt = input['person_fdepth'].to(self.device) # for ground truth
self.bdepth_gt = input['person_bdepth'].to(self.device) # for ground truth




def forward(self):
Expand All @@ -159,8 +162,9 @@ def forward(self):
self.fdepth_pred, self.bdepth_pred = torch.split(self.output['depth'], [1,1], 1)
self.fdepth_pred = torch.tanh(self.fdepth_pred)
self.bdepth_pred = torch.tanh(self.bdepth_pred)
self.fdepth_diff = self.fdepth_pred - self.fdepth_gt # just for visual
self.bdepth_diff = self.bdepth_pred - self.bdepth_gt # fust for visual
if self.isTrain:
self.fdepth_diff = self.fdepth_pred - self.fdepth_gt # just for visual
self.bdepth_diff = self.bdepth_pred - self.bdepth_gt # fust for visual

if self.output['segmt'] is not None:
self.segmt_pred = self.output['segmt']
Expand Down
4 changes: 3 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import time
import numpy as np
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from options.test_options import TestOptions
from data import create_dataset
Expand Down Expand Up @@ -132,4 +134,4 @@
bnormal_vis = (bnormal_vis * 255).astype(np.uint8)
bnormal_pil = Image.fromarray(bnormal_vis)
bnormal_pil.save(os.path.join(results_dir, 'final-normal-vis', im_name.replace('front.png','back_normal.png')))
print(f'\nTest {opt.model} down.')
print(f'\nTesting {opt.model} finished.')

0 comments on commit 3c9d16c

Please sign in to comment.