Valid samples mask inserted
giannipint committed Oct 1, 2021
1 parent ccfd1fd commit 1cd53d6
Showing 5 changed files with 95 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@ __pycache__/
ckpt/
.vs/
misc/__pycache__/
ckpt_old/
1 change: 1 addition & 0 deletions README.md
@@ -17,6 +17,7 @@ This repo is a **python** implementation where you can test **depth inference**
![](assets/overview.png)

## Updates
* 2021-10-01: Inference script updated with valid mask and ground truth loader
* 2021-08-13: IMPORTANT: Fixed bug in weights init: model and pre-trained weights updated
- REPLACE PREVIOUS MODEL AND WEIGHTS
* 2021-07-21: Network source code and demo released
(One changed file could not be rendered in the diff view.)
25 changes: 20 additions & 5 deletions inference.py
@@ -2,6 +2,7 @@
import argparse
import numpy as np
from PIL import Image
from scipy.interpolate import interp2d

import matplotlib.pyplot as plt
import torch
@@ -11,6 +12,7 @@


def_img = 'example/001ad2ad14234e06b2d996d71bb96fc4_color.png'#
def_gt = 'example/001ad2ad14234e06b2d996d71bb96fc4_depth.png'#

def_pth ='ckpt/resnet50_m3d.pth'

@@ -20,6 +22,7 @@
parser.add_argument('--pth', required=False, default = def_pth,
help='path to load saved checkpoint.')
parser.add_argument('--img_glob', required=False, default = def_img)
parser.add_argument('--gt_depth', required=False, default = def_gt)
parser.add_argument('--no_cuda', action='store_true', default = False)

args = parser.parse_args()
@@ -33,20 +36,32 @@
#
img_pil = Image.open(args.img_glob)

full_W,full_H = img_pil.size

H, W = 512,1024

img_pil = img_pil.resize((W,H), Image.BICUBIC)
img = np.array(img_pil, np.float32)[..., :3] / 255.


with torch.no_grad():
x_img = torch.FloatTensor(img.transpose([2, 0, 1]).copy())

####predict depth
x_img = torch.FloatTensor(img.transpose([2, 0, 1]).copy())
x = x_img.unsqueeze(0)
depth = net(x.to(device))

depth = net(x.to(device))
depth_c = depth.cpu().numpy().astype(np.float32).squeeze(0)

####create valid mask for Matterport sensor
depth_gt = np.array(Image.open(args.gt_depth), np.float32)
xrange = lambda x: np.linspace(0, 1, x)
f = interp2d(xrange(full_W), xrange(full_H), depth_gt, kind="linear")##kind="cubic")
depth_gt = f(xrange(W), xrange(H))
depth_gt /= 4000.0 #####matterport scale to meters
depth_mask = ((depth_gt <= 127.0) & (depth_gt > 0.)).astype(np.uint8)
depth_mask = torch.FloatTensor(depth_mask.copy())


depth_c = (depth_mask*depth.cpu()).numpy().astype(np.float32).squeeze(0)


plt.figure(0)
plt.title('prediction')
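For context, the new block above resamples the raw Matterport ground-truth depth to the 512x1024 network resolution, converts it from the sensor scale to meters (divide by 4000), and keeps only pixels with a valid reading. Below is a minimal, hypothetical sketch, not part of this commit, of how that mask and resampled ground truth could be used to score a prediction on valid pixels only. Since scipy.interpolate.interp2d is deprecated in recent SciPy releases (and removed in 1.14), the resampling helper uses RegularGridInterpolator instead; the names resample_depth, masked_depth_errors, and depth_pred are illustrative.

import numpy as np
from scipy.interpolate import RegularGridInterpolator

def resample_depth(depth_raw, out_h, out_w):
    # Bilinear resampling of a ground-truth depth map to the network
    # resolution, analogous to the interp2d call in inference.py but
    # using an interpolator available in current SciPy versions.
    in_h, in_w = depth_raw.shape
    interp = RegularGridInterpolator(
        (np.linspace(0, 1, in_h), np.linspace(0, 1, in_w)),
        depth_raw, method="linear")
    yy, xx = np.meshgrid(np.linspace(0, 1, out_h),
                         np.linspace(0, 1, out_w), indexing="ij")
    return interp(np.stack([yy, xx], axis=-1))

def masked_depth_errors(depth_pred, depth_gt, mask):
    # MAE and RMSE restricted to pixels where the sensor returned a
    # valid reading (mask == 1).
    valid = mask.astype(bool)
    diff = depth_pred[valid] - depth_gt[valid]
    return np.abs(diff).mean(), np.sqrt((diff ** 2).mean())

With the names used in the diff, something like masked_depth_errors(depth_c, depth_gt, depth_mask.numpy()) would report errors on valid pixels only, assuming depth_c and depth_gt are both (H, W) arrays in meters.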
74 changes: 73 additions & 1 deletion slice_model.py
@@ -221,7 +221,7 @@ def forward(self, x):
feature = self.slicing_module(conv_list, x.shape[3])

feature = feature.permute(2, 0, 1)

output, hidden = self.bi_rnn(feature)

output = self.drop_out(output)
@@ -236,3 +236,75 @@ def forward(self, x):


return depth


if __name__ == '__main__':
print('testing SliceNet')

device = torch.device('cuda')

net = SliceNet('resnet50',full_size = True).to(device)

pytorch_total_params = sum(p.numel() for p in net.parameters())

for name, param in net.named_parameters():
if param.requires_grad:
print(name, param.numel())

print('pytorch_total_params', pytorch_total_params)

pytorch_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)

print('pytorch_trainable_params', pytorch_trainable_params)

decoder_params = 0

for name, param in net.named_parameters():
if (param.requires_grad and ('decoder' in name) ):
print(name, param.numel())
decoder_params += param.numel()

print('equi decoder parameters', decoder_params)

rnn_params = 0

for name, param in net.named_parameters():
if (param.requires_grad and ('rnn' in name) ):
print(name, param.numel())
rnn_params += param.numel()

print('rnn decoder parameters', rnn_params)

h_encoder_params = 0

for name, param in net.named_parameters():
if (param.requires_grad and ('reduce_height_module' in name) ):
print(name, param.numel())
h_encoder_params += param.numel()

print('height encoder parameters', h_encoder_params)

encoder_params = 0

for name, param in net.named_parameters():
if (param.requires_grad and ('feature_extractor' in name) ):
print(name, param.numel())
encoder_params += param.numel()

print('resnet encoder parameters', encoder_params)

##batch = torch.ones(1, 3, 256, 512).to(device)
batch = torch.ones(1, 3, 512, 1024).to(device)

##with torch.no_grad():
torch.cuda.synchronize()
t0 = time.time()
out_depth = net(batch)
torch.cuda.synchronize()
elapsed_fp = time.time()-t0

print('time cost',elapsed_fp)

print('out_depth shape', out_depth.shape)

print('test done')
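The timing at the end of the test above measures a single forward pass with the torch.no_grad() guard commented out, so the number includes autograd bookkeeping and any one-off CUDA initialization. A minimal sketch of a steadier measurement, assuming the same SliceNet construction used in this file (the name benchmark_forward is illustrative, not part of the repository):

import time
import torch

def benchmark_forward(net, device, iters=20, warmup=5, h=512, w=1024):
    # Average forward-pass latency over several runs, after a short
    # warm-up to absorb cuDNN autotuning and lazy CUDA initialization.
    net.eval()
    batch = torch.ones(1, 3, h, w, device=device)
    with torch.no_grad():
        for _ in range(warmup):
            net(batch)
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(iters):
            net(batch)
        torch.cuda.synchronize()
    return (time.time() - t0) / iters

# e.g. print('mean forward time', benchmark_forward(net, device))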
