{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from conch.open_clip_custom import create_model_from_pretrained, tokenize, get_tokenizer\n",
    "import torch\n",
    "import os\n",
    "from PIL import Image\n",
    "from pathlib import Path\n",
    "\n",
    "# Show the value of every expression statement in a cell, not only the last one.\n",
    "from IPython.core.interactiveshell import InteractiveShell\n",
    "InteractiveShell.ast_node_interactivity = \"all\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from the parent of the initial working directory so that the relative\n",
    "# paths used below ('./checkpoints/...', './docs/...') resolve from there.\n",
    "# `root` is resolved only once per kernel session: re-running this cell would\n",
    "# otherwise resolve '../' against the already-changed working directory and\n",
    "# climb one level higher on every execution.\n",
    "if 'root' not in globals():\n",
    "    root = Path('../').resolve()\n",
    "os.chdir(root)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load model from checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the model and the matching image preprocessing callable from a local checkpoint.\n",
    "model, preprocess = create_model_from_pretrained(model_cfg='conch_ViT-B-16',\n",
    "                                                 checkpoint_path='./checkpoints/CONCH/pytorch_model.bin')\n",
    "# Switch to evaluation mode; assigning to `_` keeps the returned module from being displayed.\n",
    "_ = model.eval()\n",
    "\n",
    "# Use the default CUDA device when one is available, otherwise fall back to the CPU.\n",
    "# A fixed index such as 'cuda:1' fails on machines with a single GPU; to target a\n",
    "# specific GPU, set CUDA_VISIBLE_DEVICES or edit the device string below.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = model.to(device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open an image and preprocess it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = Image.open('./docs/roi1.jpg')\n",
    "# Preprocess the image, add a leading batch dimension, and move it to the model's device.\n",
    "image_tensor = preprocess(image).unsqueeze(0).to(device)\n",
    "\n",
    "# visualize thumbnail\n",
    "image.resize((224, 224))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load tokenizer and specify some prompts."
] },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = get_tokenizer()\n",
    "# Candidate captions that will be ranked against the image further below.\n",
    "prompts = [\n",
    "    'photomicrograph illustrating invasive ductal carcinoma of the breast, H&E stain',\n",
    "    'a case of invasive lobular carcinoma as visualized using H&E stain',\n",
    "    'high magnification view of a breast cancer tumor, H&E stain',\n",
    "    'clear cell renal cell carcinoma',\n",
    "    'lung adenocarcinoma, H&E stain',\n",
    "    'IHC stain for CDX2 in a case of metastatic colorectal adenocarcinoma',\n",
    "    'an image of a cat',\n",
    "    'High-grade angiosarcoma characterized by solid areas of polygonal and spindled cells as well as necrosis',\n",
    "    'metastatic tumor to the lymph node, GATA3 staining',\n",
    "    'epidermis with follicular ostia'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize all prompts and move the result to the same device as the model.\n",
    "tokenized_prompts = tokenize(texts=prompts, tokenizer=tokenizer).to(device)\n",
    "tokenized_prompts.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Embed the prompts and the image and compute the cosine similarity between the image and the prompts. Note that for illustrative purposes, we only show image --> text retrieval but the reverse direction is analogous and can be performed using the same function calls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No gradients are needed for retrieval, so run both encoders in inference mode.\n",
    "with torch.inference_mode():\n",
    "    image_embeddings = model.encode_image(image_tensor)\n",
    "    text_embeddings = model.encode_text(tokenized_prompts)\n",
    "    # One score per prompt; squeeze(0) drops the single-image batch dimension.\n",
    "    # NOTE(review): a dot product equals cosine similarity only if the encoders\n",
    "    # return L2-normalized embeddings -- presumably their default; confirm.\n",
    "    sim_scores = (image_embeddings @ text_embeddings.T).squeeze(0)\n",
    "\n",
    "print(\"Ranked list of prompts based on cosine similarity with the image:\")\n",
    "ranked_scores, ranked_idx = torch.sort(sim_scores, descending=True)\n",
    "# Convert to plain Python numbers once so the prompt list is indexed with ints.\n",
    "for idx, score in zip(ranked_idx.tolist(), ranked_scores.tolist()):\n",
    "    print(f\"\\\"{prompts[idx]}\\\": {score:.3f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}