Skip to content

Commit

Permalink
Fix timm incompatibility
Browse files (browse the repository at this point in the history)
Backports isl-org/MiDaS#234
Fixes #319
Fixes #323
  • Loading branch information
semjon00 committed Aug 28, 2023
1 parent 80f9fa3 commit d412f85
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
10 changes: 5 additions & 5 deletions dmidas/backbones/beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _get_rel_pos_bias(self, window_size):
old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3]

old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear")
new_sub_table = F.interpolate(old_sub_table, size=(int(new_height),int(new_width)), mode="bilinear")
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

new_relative_position_bias_table = torch.cat(
Expand Down Expand Up @@ -96,12 +96,12 @@ def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tenso
Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
"""
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x + self.drop_path1(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path2(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution,
x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), resolution,
shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
return x


Expand Down
4 changes: 3 additions & 1 deletion dzoedepth/models/base_models/midas.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_m
kwargs = MidasCore.parse_img_size(kwargs)
img_size = kwargs.pop("img_size", [384, 384])
print("img_size", img_size)
midas = torch.hub.load("intel-isl/MiDaS", midas_model_type,
# TODO: use locally-bundled midas
# The repo should be changed back to isl-org/MiDaS once this MR lands
midas = torch.hub.load("AyaanShah2204/MiDaS", midas_model_type,
pretrained=use_pretrained_midas, force_reload=force_reload)
kwargs.update({'keep_aspect_ratio': force_keep_ar})
midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features,
Expand Down
6 changes: 3 additions & 3 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def ensure(module_name, min_version=None):
msg = f'{requirement} requirement for depthmap script'
launch.run_pip(cmd, msg)

if not launch.is_installed("timm"): #0.6.7 # For midas
launch.run_pip('install --force-reinstall "timm==0.6.12"', "timm requirement for depthmap script")

ensure('timm', '0.9.2') # Just in case

ensure('matplotlib')

Expand All @@ -47,8 +47,8 @@ def ensure(module_name, min_version=None):

if not launch.is_installed("networkx"):
launch.run_pip('install install "networkx==2.5"', "networkx requirement for depthmap script")

if platform.system() == 'Windows':
ensure('pyqt5')

if platform.system() == 'Darwin':
ensure('pyqt6')

0 comments on commit d412f85

Please sign in to comment.