Skip to content

Commit

Permalink
update: pram
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincentqyw committed Aug 21, 2024
1 parent 1b8c5c5 commit 75fcc9d
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@
[submodule "third_party/mast3r"]
path = third_party/mast3r
url = https://github.com/naver/mast3r
[submodule "third_party/pram"]
path = third_party/pram
url = https://github.com/agipro/pram.git
11 changes: 4 additions & 7 deletions hloc/extractors/sfd2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: UTF-8 -*-
import sys
from pathlib import Path

Expand All @@ -7,11 +6,9 @@
from .. import logger
from ..utils.base_model import BaseModel

pram_path = Path(__file__).parent / "../../third_party/pram"
sys.path.append(str(pram_path))

from nets.sfd2 import load_sfd2

tp_path = Path(__file__).parent / "../../third_party"
sys.path.append(str(tp_path))
from pram.nets.sfd2 import load_sfd2

class SFD2(BaseModel):
default_conf = {
Expand All @@ -26,7 +23,7 @@ def _init(self, conf):
self.norm_rgb = tvf.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
model_fn = pram_path / "weights" / self.conf["model_name"]
model_fn = tp_path / "pram" / "weights" / self.conf["model_name"]
self.net = load_sfd2(weight_path=model_fn).eval()

logger.info("Load SFD2 model done.")
Expand Down
11 changes: 5 additions & 6 deletions hloc/matchers/imp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: UTF-8 -*-
import sys
from pathlib import Path

Expand All @@ -7,10 +6,9 @@
from .. import DEVICE, logger
from ..utils.base_model import BaseModel

pram_path = Path(__file__).parent / "../../third_party/pram"
sys.path.append(str(pram_path))

from nets.gml import GML
tp_path = Path(__file__).parent / "../../third_party"
sys.path.append(str(tp_path))
from pram.nets.gml import GML


class IMP(BaseModel):
Expand All @@ -33,7 +31,8 @@ class IMP(BaseModel):

def _init(self, conf):
self.conf = {**self.default_conf, **conf}
weight_path = pram_path / "weights" / self.conf["model_name"]
weight_path = tp_path / "pram" / "weights" / self.conf["model_name"]
# self.net = nets.gml(self.conf).eval().to(DEVICE)
self.net = GML(self.conf).eval().to(DEVICE)
self.net.load_state_dict(
torch.load(weight_path, map_location="cpu")["model"], strict=True
Expand Down
1 change: 1 addition & 0 deletions third_party/pram
Submodule pram added at 742ff4

0 comments on commit 75fcc9d

Please sign in to comment.