diff --git a/.gitmodules b/.gitmodules index 7a9c595..d61f3e0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -64,3 +64,6 @@ [submodule "third_party/mast3r"] path = third_party/mast3r url = https://github.com/naver/mast3r +[submodule "third_party/pram"] + path = third_party/pram + url = https://github.com/agipro/pram.git diff --git a/hloc/extractors/sfd2.py b/hloc/extractors/sfd2.py index 9fb76ed..2a8a004 100644 --- a/hloc/extractors/sfd2.py +++ b/hloc/extractors/sfd2.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- import sys from pathlib import Path @@ -7,11 +6,9 @@ from .. import logger from ..utils.base_model import BaseModel -pram_path = Path(__file__).parent / "../../third_party/pram" -sys.path.append(str(pram_path)) - -from nets.sfd2 import load_sfd2 - +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.sfd2 import load_sfd2 class SFD2(BaseModel): default_conf = { @@ -26,7 +23,7 @@ def _init(self, conf): self.norm_rgb = tvf.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) - model_fn = pram_path / "weights" / self.conf["model_name"] + model_fn = tp_path / "pram" / "weights" / self.conf["model_name"] self.net = load_sfd2(weight_path=model_fn).eval() logger.info("Load SFD2 model done.") diff --git a/hloc/matchers/imp.py b/hloc/matchers/imp.py index ca64980..05c3cb9 100644 --- a/hloc/matchers/imp.py +++ b/hloc/matchers/imp.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- import sys from pathlib import Path @@ -7,10 +6,9 @@ from .. import DEVICE, logger from ..utils.base_model import BaseModel -pram_path = Path(__file__).parent / "../../third_party/pram" -sys.path.append(str(pram_path)) - -from nets.gml import GML +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.gml import GML class IMP(BaseModel): @@ -33,7 +31,8 @@ class IMP(BaseModel): def _init(self, conf): self.conf = {**self.default_conf, **conf} - weight_path = pram_path / "weights" / self.conf["model_name"] + weight_path = tp_path / "pram" / "weights" / self.conf["model_name"] + # self.net = nets.gml(self.conf).eval().to(DEVICE) self.net = GML(self.conf).eval().to(DEVICE) self.net.load_state_dict( torch.load(weight_path, map_location="cpu")["model"], strict=True diff --git a/third_party/pram b/third_party/pram new file mode 160000 index 0000000..742ff42 --- /dev/null +++ b/third_party/pram @@ -0,0 +1 @@ +Subproject commit 742ff4241105bfaee0039e01556c79fb9be5c8b8