import sys sys.path.append("../CricaVPR") import torch from PIL import Image from collections import OrderedDict import torchvision.transforms import network as crica_net_lib class CricaModel: def __init__(self): self.conf = {"name": "crica"} model = crica_net_lib.CricaVPRNet() checkpoint = torch.load("../CricaVPR/CricaVPR.pth") if "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] else: state_dict = checkpoint if list(state_dict.keys())[0].startswith("module"): state_dict = OrderedDict( {k.replace("module.", ""): v for (k, v) in state_dict.items()} ) model.load_state_dict(state_dict) model = model.to("cuda") model.eval() self.model = model self.transform = torchvision.transforms.Compose( [ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ] ) def process(self, name): image = Image.open(name).convert("RGB") image = self.transform(image) image = torchvision.transforms.functional.resize(image, (224, 224)) image_descriptor = self.model(image.unsqueeze(0).cuda()) image_descriptor = image_descriptor.squeeze().cpu().numpy() # 10752 return image_descriptor