wrapper.py
from abc import ABC, abstractmethod

from PIL.Image import Image
from torch import Tensor


class VlmWrapper(ABC):
    model: object
    image_processor: object
    tokenizer: object

    @abstractmethod
    def load_model(self, model_name: str, quantize: bool = False):
        """Load the model.

        Args:
            model_name (str): Model name.
            quantize (bool, optional): Whether to 4-bit quantize the model. Defaults to False.
        """
        pass

    @abstractmethod
    def prepare_inputs(
        self,
        image: Tensor | Image,
        prompt: str,
    ) -> dict:
        """Prepare inputs for the model.

        Args:
            image (Tensor | Image): Image tensor or PIL Image.
            prompt (str): Prompt string.

        Returns:
            dict: Dictionary of model inputs, so that self.model.forward(**inputs) can be called.
        """
        pass

    @abstractmethod
    def get_logits(
        self,
        images: Tensor | Image,
        prompt: str,
        layer_wise: bool = False,
    ) -> Tensor:
        """Get logits from the model.

        Args:
            images (Tensor | Image): Image tensor or PIL Image.
            prompt (str): Prompt string.
            layer_wise (bool, optional): Whether to return layer-wise logits. Defaults to False.

        Returns:
            Tensor: Logits of shape (vocab_size, batch), or (vocab_size, batch, num_layers)
                when layer_wise is True.
        """
        pass
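
To illustrate the contract, below is a minimal sketch of a concrete subclass built on Hugging Face Transformers. The checkpoint name, the AutoProcessor/AutoModelForVision2Seq classes, the single-processor simplification, and the last-token logits choice are assumptions for illustration, not part of this repository; layer-wise logits are left unimplemented.

import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig


class HfVlmWrapper(VlmWrapper):
    """Hypothetical wrapper around a Transformers vision-language model (sketch)."""

    def load_model(self, model_name: str, quantize: bool = False):
        # Optional 4-bit quantization via bitsandbytes, mirroring the ABC's flag.
        quant_config = BitsAndBytesConfig(load_in_4bit=True) if quantize else None
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_name, quantization_config=quant_config
        )
        # Simplification: store the combined processor as image_processor and
        # expose its tokenizer separately, as the ABC declares both attributes.
        processor = AutoProcessor.from_pretrained(model_name)
        self.image_processor = processor
        self.tokenizer = processor.tokenizer

    def prepare_inputs(self, image, prompt: str) -> dict:
        # The processor output is dict-like and usable as self.model(**inputs).
        return dict(self.image_processor(images=image, text=prompt, return_tensors="pt"))

    def get_logits(self, images, prompt: str, layer_wise: bool = False) -> torch.Tensor:
        inputs = self.prepare_inputs(images, prompt)
        with torch.no_grad():
            outputs = self.model(**inputs)
        if layer_wise:
            # Layer-wise logits would require projecting each hidden state
            # through the LM head (a "logit lens"); omitted in this sketch.
            raise NotImplementedError
        # Last-token logits, permuted to (vocab_size, batch) per the docstring.
        return outputs.logits[:, -1, :].permute(1, 0)

Usage, with a hypothetical checkpoint name:

wrapper = HfVlmWrapper()
wrapper.load_model("llava-hf/llava-1.5-7b-hf", quantize=True)
logits = wrapper.get_logits(pil_image, "Describe the image.")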