In [None]:
from conch.open_clip_custom import create_model_from_pretrained, tokenize, get_tokenizer
import torch
import os
from PIL import Image
from pathlib import Path

# Make Jupyter echo the value of every top-level expression statement in a
# cell, not only the last one. As a consequence, calls whose return value
# should stay hidden must be assigned to a throwaway name (e.g. `_ = ...`).
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [None]:
# Resolve the project root (the parent of the starting working directory) only
# once per kernel session. `Path('../')` is resolved against the *current*
# working directory, so recomputing it after the chdir below would climb one
# directory higher on every re-run of this cell and break the relative paths
# used by the following cells.
if 'root' not in globals():
    root = Path('../').resolve()
os.chdir(root)

Load model from checkpoint

In [None]:
# Build the CONCH ViT-B/16 model from the local checkpoint; the call also
# returns the matching image preprocessing callable.
model, preprocess = create_model_from_pretrained(
    model_cfg='conch_ViT-B-16',
    checkpoint_path='./checkpoints/CONCH/pytorch_model.bin',
)
# Switch to evaluation mode. The result is bound to `_` so that it is not
# echoed as cell output (all expression values are displayed in this notebook).
_ = model.eval()

# Keep the original preference for the second GPU (index 1), but only when it
# actually exists: on a single-GPU machine CUDA is available while ordinal 1
# is not, and moving the model to 'cuda:1' would raise. Fall back to the first
# GPU, then to the CPU.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model = model.to(device=device)

Open an image and preprocess it

In [None]:
# Load the example image from the docs folder.
image = Image.open('./docs/roi1.jpg')

# Run the preprocessing callable returned with the model, add a leading
# batch axis of size one, then move the result to the selected device.
single_image_batch = preprocess(image).unsqueeze(dim=0)
image_tensor = single_image_batch.to(device=device)

# Show a 224x224 thumbnail of the raw image as cell output.
image.resize(size=(224, 224))

Load tokenizer and specify some prompts.

In [None]:
# Candidate text prompts that will be ranked against the image.
prompts = [
    'photomicrograph illustrating invasive ductal carcinoma of the breast, H&E stain',
    'a case of invasive lobular carcinoma as visualized using H&E stain',
    'high magnification view of a breast cancer tumor, H&E stain',
    'clear cell renal cell carcinoma',
    'lung adenocarcinoma, H&E stain',
    'IHC stain for CDX2 in a case of metastatic colorectal adenocarcinoma',
    'an image of a cat',
    'High-grade angiosarcoma characterized by solid areas of polygonal and spindled cells as well as necrosis',
    'metastatic tumor to the lymph node, GATA3 staining',
    'epidermis with follicular ostia',
]

# Tokenizer used to turn the prompts into model inputs.
tokenizer = get_tokenizer()

In [None]:
# Tokenize all prompts in one call, then move the tokens to the model's device.
prompt_tokens = tokenize(tokenizer=tokenizer, texts=prompts)
tokenized_prompts = prompt_tokens.to(device)

# Display the shape of the tokenized prompts as a quick sanity check.
tokenized_prompts.shape

Embed the prompts and the image, then compute the cosine similarity between the image and each prompt. Note that, for illustrative purposes, we only show image → text retrieval; the reverse direction (text → image) is analogous and can be performed using the same function calls.

In [None]:
# Encode both modalities with autograd disabled; only forward passes are needed.
with torch.inference_mode():
    image_embedings = model.encode_image(image_tensor)
    text_embedings = model.encode_text(tokenized_prompts)
    # One similarity score per prompt for the single image in the batch.
    # NOTE(review): this dot product is a cosine similarity only if both
    # encoders return L2-normalized embeddings -- confirm against the model.
    sim_scores = torch.matmul(image_embedings, text_embedings.T).squeeze(0)

print("Ranked list of prompts based on cosine similarity with the image:")
# Sort from most to least similar, keeping the original prompt positions.
ranked_scores, ranked_idx = sim_scores.sort(descending=True)
# Convert to plain Python numbers for list indexing and float formatting.
for idx, score in zip(ranked_idx.tolist(), ranked_scores.tolist()):
    print(f"\"{prompts[idx]}\": {score:.3f}")