add embedding
zhiguoxu committed Jul 9, 2024
1 parent 8fc57e7 commit dce6918
Showing 14 changed files with 182 additions and 0 deletions.
Empty file added core/rag/__init__.py
Empty file.
Empty file added core/rag/embedding/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions core/rag/embedding/embedding.py
@@ -0,0 +1,12 @@
from abc import abstractmethod, ABC
from typing import List


class Embedding(ABC):
@abstractmethod
def embed_query(self, text: str) -> List[float]:
...

@abstractmethod
def embed_documents(self, text_list: List[str]) -> List[List[float]]:
...
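For orientation, any embedding backend only has to implement these two methods. A minimal sketch of a concrete subclass (illustrative only, not part of this commit; embed_fn is a hypothetical callable such as an API client):

from typing import Callable, List

from core.rag.embedding.embedding import Embedding


class CallableEmbedding(Embedding):
    # Illustrative adapter: delegate both abstract methods to a user-supplied function.
    def __init__(self, embed_fn: Callable[[str], List[float]]):
        self.embed_fn = embed_fn  # hypothetical callable returning one vector per text

    def embed_query(self, text: str) -> List[float]:
        return self.embed_fn(text)

    def embed_documents(self, text_list: List[str]) -> List[List[float]]:
        return [self.embed_fn(text) for text in text_list]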
Empty file added core/rag/embedding/huggingface/__init__.py
Empty file.
102 changes: 102 additions & 0 deletions core/rag/embedding/huggingface/hf_embedding.py
@@ -0,0 +1,102 @@
from typing import Any, Dict, List
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

from core.rag.embedding.embedding import Embedding
from core.rag.utils import infer_torch_device, to_list, OneOrMany
from core.utils.utils import filter_kwargs_by_pydantic, filter_kwargs_by_method

DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_BATCH_SIZE = 32
DEFAULT_HUGGINGFACE_LENGTH = 512
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = "Represent the question for retrieving supporting documents: "
DEFAULT_QUERY_BGE_INSTRUCTION_EN = "Represent this question for searching relevant passages: "
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"


class HuggingfaceEmbedding(BaseModel, Embedding):
model: str = DEFAULT_MODEL_NAME
client: SentenceTransformer
# Encode parameters.
query_instruction: str | None = None
document_instruction: str | None = None
batch_size: int = DEFAULT_BATCH_SIZE
normalize: bool = Field(default=True, description="Normalize embeddings or not.")
show_progress: bool = False
multi_process: bool = False # Run encode() on multiple GPUs.
extra_encode_kwargs: Dict[str, Any] = Field(default_factory=dict)

def __init__(self,
model: str = DEFAULT_MODEL_NAME,
# Encode parameters.
query_instruction: str | None = None,
document_instruction: str | None = None,
normalize: bool = True,
batch_size: int = DEFAULT_BATCH_SIZE,
show_progress: bool = False,
# SentenceTransformer init parameters.
device: str | None = None,
trust_remote_code: bool = False,
cache_folder: str | None = None,
max_length: int | None = None,
**model_or_encode_kwargs: Any):
# Init instruction.
if model.startswith("BAAI/bge-"):
if "-zh" in model.lower():
query_instruction = query_instruction or DEFAULT_QUERY_BGE_INSTRUCTION_ZH
else:
query_instruction = query_instruction or DEFAULT_QUERY_BGE_INSTRUCTION_EN
elif "instructor" in model.lower():
query_instruction = query_instruction or DEFAULT_QUERY_INSTRUCTION

if "instructor" in model.lower():
            document_instruction = document_instruction or DEFAULT_EMBED_INSTRUCTION  # document (not query) instruction

device = device or infer_torch_device()

# Collect model init kwargs.
model_init_kwargs = filter_kwargs_by_method(SentenceTransformer.__init__,
{**locals(), **model_or_encode_kwargs},
exclude_none=True)
client = SentenceTransformer(model, **model_init_kwargs)
if max_length:
client.max_seq_length = max_length

extra_encode_kwargs = filter_kwargs_by_method(client.encode, model_or_encode_kwargs)
kwargs = filter_kwargs_by_pydantic(HuggingfaceEmbedding, locals(), exclude_none=True)
super().__init__(**kwargs)

def embed_query(self, text: str) -> List[float]:
return self._embed(text, prompt=self.query_instruction)[0]

def embed_documents(self, text_list: List[str]) -> List[List[float]]:
return self._embed(text_list, prompt=self.document_instruction)

def _embed(self, sentences: OneOrMany[str], prompt: str | None = None) -> List[List[float]]:
sentences = [s.replace("\n", " ") for s in to_list(sentences)]
if self.multi_process:
import sentence_transformers
pool = self.client.start_multi_process_pool()
embeddings = self.client.encode_multi_process(sentences, pool, prompt=prompt, batch_size=self.batch_size)
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
else:
embeddings = self.client.encode(sentences,
prompt=prompt,
batch_size=self.batch_size,
normalize_embeddings=self.normalize,
show_progress_bar=self.show_progress,
**self.extra_encode_kwargs) # type: ignore[assignment]
return embeddings.tolist()

@property
def max_length(self):
return self.client.max_seq_length

@max_length.setter
def max_length(self, value: int):
self.client.max_seq_length = value

class Config:
arbitrary_types_allowed = True
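A minimal usage sketch, assuming network access to download the weights; the model id, device, and max_length below are example values, not defaults taken from this commit:

from core.rag.embedding.huggingface.hf_embedding import HuggingfaceEmbedding

# Any SentenceTransformer-compatible model id works; BGE models pick up a default query instruction.
embedding = HuggingfaceEmbedding("BAAI/bge-small-en-v1.5", device="cpu", max_length=256)

query_vec = embedding.embed_query("What is vector search?")  # a single List[float]
doc_vecs = embedding.embed_documents(["Vector search maps text to dense vectors."])

print(len(query_vec), len(doc_vecs[0]))  # both equal the model's embedding dimension

Because normalize defaults to True, the returned vectors are unit length, so a plain dot product already gives cosine similarity.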
20 changes: 20 additions & 0 deletions core/rag/utils.py
@@ -0,0 +1,20 @@
from typing import TypeVar, Sequence, Any, Union


def infer_torch_device() -> str:
import torch
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"


T = TypeVar("T")
OneOrMany = Union[T, Sequence[T]]


def to_list(obj: Any) -> list:
if isinstance(obj, str):
return [obj]
return list(obj) if isinstance(obj, Sequence) else [obj]
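A quick illustration of how these helpers behave, since _embed relies on a bare string being treated as one sentence rather than as a sequence of characters:

from core.rag.utils import to_list, infer_torch_device

assert to_list("hello") == ["hello"]        # str is wrapped, not iterated
assert to_list(["a", "b"]) == ["a", "b"]    # other sequences pass through
assert to_list(42) == [42]                  # anything else becomes a one-element list

print(infer_torch_device())  # "cuda", "mps", or "cpu" depending on the machine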
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
48 changes: 48 additions & 0 deletions tutorials/rag/embedding.ipynb
@@ -0,0 +1,48 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from core.rag.embedding.huggingface.hf_embedding import HuggingfaceEmbedding\n",
"\n",
"embedding = HuggingfaceEmbedding(\"hkunlp/instructor-base\")\n",
"embedding.embed_query(\"你好\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
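As a possible follow-up cell (not part of this commit), document embeddings from the same model can be ranked against a query; the numpy dependency here is an assumption:

import numpy as np

docs = ["Hello, how are you?", "The weather is nice today."]
doc_vecs = np.array(embedding.embed_documents(docs))
query_vec = np.array(embedding.embed_query("Hi there"))

# normalize defaults to True, so the dot product equals cosine similarity.
scores = doc_vecs @ query_vec
print(docs[int(np.argmax(scores))])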
File renamed without changes.
File renamed without changes.
File renamed without changes.
