diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 5a313d7bb6..b5111b75ca 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -11,9 +11,13 @@ BM25RetrieverConfig, ChromaIndexConfig, ChromaRetrieverConfig, + ElasticsearchIndexConfig, + ElasticsearchRetrieverConfig, + ElasticsearchStoreConfig, FAISSRetrieverConfig, LLMRankerConfig, ) +from metagpt.utils.exceptions import handle_exception DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt" QUESTION = "What are key qualities to be a good writer?" @@ -39,12 +43,22 @@ def rag_key(self) -> str: class RAGExample: """Show how to use RAG.""" - def __init__(self): - self.engine = SimpleEngine.from_docs( - input_files=[DOC_PATH], - retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], - ranker_configs=[LLMRankerConfig()], - ) + def __init__(self, engine: SimpleEngine = None): + self._engine = engine + + @property + def engine(self): + if not self._engine: + self._engine = SimpleEngine.from_docs( + input_files=[DOC_PATH], + retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], + ranker_configs=[LLMRankerConfig()], + ) + return self._engine + + @engine.setter + def engine(self, value: SimpleEngine): + self._engine = value async def run_pipeline(self, question=QUESTION, print_title=True): """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like: @@ -97,6 +111,7 @@ async def add_docs(self): self.engine.add_docs([travel_filepath]) await self.run_pipeline(question=travel_question, print_title=False) + @handle_exception async def add_objects(self, print_title=True): """This example show how to add objects. @@ -154,20 +169,41 @@ async def init_and_query_chromadb(self): """ self._print_title("Init And Query ChromaDB") - # save index + # 1. save index output_dir = DATA_PATH / "rag" SimpleEngine.from_docs( input_files=[TRAVEL_DOC_PATH], retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)], ) - # load index - engine = SimpleEngine.from_index( - index_config=ChromaIndexConfig(persist_path=output_dir), + # 2. load index + engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir)) + + # 3. query + answer = await engine.aquery(TRAVEL_QUESTION) + self._print_query_result(answer) + + @handle_exception + async def init_and_query_es(self): + """This example show how to use es. how to save and load index. will print something like: + + Query Result: + Bob likes traveling. + """ + self._print_title("Init And Query Elasticsearch") + + # 1. create es index and save docs + store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200") + engine = SimpleEngine.from_docs( + input_files=[TRAVEL_DOC_PATH], + retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)], ) - # query - answer = engine.query(TRAVEL_QUESTION) + # 2. load index + engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config)) + + # 3. query + answer = await engine.aquery(TRAVEL_QUESTION) self._print_query_result(answer) @staticmethod @@ -205,6 +241,7 @@ async def main(): await e.add_objects() await e.init_objects() await e.init_and_query_chromadb() + await e.init_and_query_es() if __name__ == "__main__": diff --git a/metagpt/rag/engines/__init__.py b/metagpt/rag/engines/__init__.py index 373181384d..93699db884 100644 --- a/metagpt/rag/engines/__init__.py +++ b/metagpt/rag/engines/__init__.py @@ -1,5 +1,6 @@ """Engines init""" from metagpt.rag.engines.simple import SimpleEngine +from metagpt.rag.engines.flare import FLAREEngine -__all__ = ["SimpleEngine"] +__all__ = ["SimpleEngine", "FLAREEngine"] diff --git a/metagpt/rag/engines/flare.py b/metagpt/rag/engines/flare.py new file mode 100644 index 0000000000..dc05bd3dde --- /dev/null +++ b/metagpt/rag/engines/flare.py @@ -0,0 +1,9 @@ +"""FLARE Engine. + +Use llamaindex's FLAREInstructQueryEngine as FLAREEngine, which accepts other engines as parameters. +For example, Create a simple engine, and then pass it to FLAREEngine. +""" + +from llama_index.core.query_engine import ( # noqa: F401 + FLAREInstructQueryEngine as FLAREEngine, +) diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 02f9ca7b1b..5c58103089 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -130,10 +130,12 @@ def from_objs( retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever. ranker_configs: Configuration for rankers. """ + objs = objs or [] + retriever_configs = retriever_configs or [] + if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs): raise ValueError("In BM25RetrieverConfig, Objs must not be empty.") - objs = objs or [] nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] index = VectorStoreIndex( nodes=nodes, diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index 8f8155914d..fbdfbf1a81 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -41,7 +41,7 @@ def get_instance(self, key: Any, **kwargs) -> Any: if creator: return creator(key, **kwargs) - raise ValueError(f"Unknown config: {key}") + raise ValueError(f"Unknown config: `{type(key)}`, {key}") @staticmethod def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any: diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index 6aad695e74..f200fc94f0 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -4,6 +4,8 @@ from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex +from llama_index.core.vector_stores.types import BasePydanticVectorStore +from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore from metagpt.rag.factories.base import ConfigBasedFactory @@ -11,6 +13,8 @@ BaseIndexConfig, BM25IndexConfig, ChromaIndexConfig, + ElasticsearchIndexConfig, + ElasticsearchKeywordIndexConfig, FAISSIndexConfig, ) from metagpt.rag.vector_stores.chroma import ChromaVectorStore @@ -22,6 +26,8 @@ def __init__(self): FAISSIndexConfig: self._create_faiss, ChromaIndexConfig: self._create_chroma, BM25IndexConfig: self._create_bm25, + ElasticsearchIndexConfig: self._create_es, + ElasticsearchKeywordIndexConfig: self._create_es, } super().__init__(creators) @@ -30,31 +36,44 @@ def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex: return super().get_instance(config, **kwargs) def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex: - embed_model = self._extract_embed_model(config, **kwargs) - vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path)) storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path) - index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model) - return index - def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: - embed_model = self._extract_embed_model(config, **kwargs) + return self._index_from_storage(storage_context=storage_context, config=config, **kwargs) + def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex: + storage_context = StorageContext.from_defaults(persist_dir=config.persist_path) + + return self._index_from_storage(storage_context=storage_context, config=config, **kwargs) + + def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: db = chromadb.PersistentClient(str(config.persist_path)) chroma_collection = db.get_or_create_collection(config.collection_name) vector_store = ChromaVectorStore(chroma_collection=chroma_collection) - index = VectorStoreIndex.from_vector_store( - vector_store, - embed_model=embed_model, - ) - return index - def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex: + return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) + + def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + + return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) + + def _index_from_storage( + self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs + ) -> VectorStoreIndex: embed_model = self._extract_embed_model(config, **kwargs) - storage_context = StorageContext.from_defaults(persist_dir=config.persist_path) - index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model) - return index + return load_index_from_storage(storage_context=storage_context, embed_model=embed_model) + + def _index_from_vector_store( + self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs + ) -> VectorStoreIndex: + embed_model = self._extract_embed_model(config, **kwargs) + + return VectorStoreIndex.from_vector_store( + vector_store=vector_store, + embed_model=embed_model, + ) def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding: return self._val_from_config_or_kwargs("embed_model", config, **kwargs) diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 1cdbab14dc..17c499b766 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -33,7 +33,9 @@ class RAGLLM(CustomLLM): @property def metadata(self) -> LLMMetadata: """Get LLM metadata.""" - return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name) + return LLMMetadata( + context_window=self.context_window, num_output=self.num_output, model_name=self.model_name or "unknown" + ) @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index f05599e159..07cb1b929f 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -3,9 +3,16 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor import LLMRerank from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.postprocessor.colbert_rerank import ColbertRerank from metagpt.rag.factories.base import ConfigBasedFactory -from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig +from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor +from metagpt.rag.schema import ( + BaseRankerConfig, + ColbertRerankConfig, + LLMRankerConfig, + ObjectRankerConfig, +) class RankerFactory(ConfigBasedFactory): @@ -14,6 +21,8 @@ class RankerFactory(ConfigBasedFactory): def __init__(self): creators = { LLMRankerConfig: self._create_llm_ranker, + ColbertRerankConfig: self._create_colbert_ranker, + ObjectRankerConfig: self._create_object_ranker, } super().__init__(creators) @@ -28,6 +37,12 @@ def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank: config.llm = self._extract_llm(config, **kwargs) return LLMRerank(**config.model_dump()) + def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank: + return ColbertRerank(**config.model_dump()) + + def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank: + return ObjectSortPostprocessor(**config.model_dump()) + def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM: return self._val_from_config_or_kwargs("llm", config, **kwargs) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index ba48c753e8..a107d95733 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -6,18 +6,22 @@ import faiss from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.vector_stores.types import BasePydanticVectorStore +from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.retrievers.base import RAGRetriever from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever +from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( BaseRetrieverConfig, BM25RetrieverConfig, ChromaRetrieverConfig, + ElasticsearchKeywordRetrieverConfig, + ElasticsearchRetrieverConfig, FAISSRetrieverConfig, IndexRetrieverConfig, ) @@ -32,6 +36,8 @@ def __init__(self): FAISSRetrieverConfig: self._create_faiss_retriever, BM25RetrieverConfig: self._create_bm25_retriever, ChromaRetrieverConfig: self._create_chroma_retriever, + ElasticsearchRetrieverConfig: self._create_es_retriever, + ElasticsearchKeywordRetrieverConfig: self._create_es_retriever, } super().__init__(creators) @@ -53,20 +59,29 @@ def _create_default(self, **kwargs) -> RAGRetriever: def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: config.index = copy.deepcopy(self._extract_index(config, **kwargs)) - nodes = list(config.index.docstore.docs.values()) - return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) + + return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump()) def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: db = chromadb.PersistentClient(path=str(config.persist_path)) chroma_collection = db.get_or_create_collection(config.collection_name) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + return ChromaRetriever(**config.model_dump()) + def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + + return ElasticsearchRetriever(**config.model_dump()) + def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: return self._val_from_config_or_kwargs("index", config, **kwargs) diff --git a/metagpt/rag/rankers/object_ranker.py b/metagpt/rag/rankers/object_ranker.py new file mode 100644 index 0000000000..b8456803f6 --- /dev/null +++ b/metagpt/rag/rankers/object_ranker.py @@ -0,0 +1,55 @@ +"""Object ranker.""" + +import heapq +import json +from typing import Literal, Optional + +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.core.schema import NodeWithScore, QueryBundle +from pydantic import Field + +from metagpt.rag.schema import ObjectNode + + +class ObjectSortPostprocessor(BaseNodePostprocessor): + """Sorted by object's field, desc or asc. + + Assumes nodes is list of ObjectNode with score. + """ + + field_name: str = Field(..., description="field name of the object, field's value must can be compared.") + order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") + top_n: int = 5 + + @classmethod + def class_name(cls) -> str: + return "ObjectSortPostprocessor" + + def _postprocess_nodes( + self, + nodes: list[NodeWithScore], + query_bundle: Optional[QueryBundle] = None, + ) -> list[NodeWithScore]: + """Postprocess nodes.""" + if query_bundle is None: + raise ValueError("Missing query bundle in extra info.") + + if not nodes: + return [] + + self._check_metadata(nodes[0].node) + + sort_key = lambda node: json.loads(node.node.metadata["obj_json"])[self.field_name] + return self._get_sort_func()(self.top_n, nodes, key=sort_key) + + def _check_metadata(self, node: ObjectNode): + try: + obj_dict = json.loads(node.metadata.get("obj_json")) + except Exception as e: + raise ValueError(f"Invalid object json in metadata: {node.metadata}, error: {e}") + + if self.field_name not in obj_dict: + raise ValueError(f"Field '{self.field_name}' not found in object: {obj_dict}") + + def _get_sort_func(self): + return heapq.nlargest if self.order == "desc" else heapq.nsmallest diff --git a/metagpt/rag/retrievers/es_retriever.py b/metagpt/rag/retrievers/es_retriever.py new file mode 100644 index 0000000000..a1a0a6138d --- /dev/null +++ b/metagpt/rag/retrievers/es_retriever.py @@ -0,0 +1,17 @@ +"""Elasticsearch retriever.""" + +from llama_index.core.retrievers import VectorIndexRetriever +from llama_index.core.schema import BaseNode + + +class ElasticsearchRetriever(VectorIndexRetriever): + """Elasticsearch retriever.""" + + def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: + """Support add nodes.""" + self._index.insert_nodes(nodes, **kwargs) + + def persist(self, persist_dir: str, **kwargs) -> None: + """Support persist. + + Elasticsearch automatically saves, so there is no need to implement.""" diff --git a/metagpt/rag/retrievers/faiss_retriever.py b/metagpt/rag/retrievers/faiss_retriever.py index 7e543cce2c..80b4092923 100644 --- a/metagpt/rag/retrievers/faiss_retriever.py +++ b/metagpt/rag/retrievers/faiss_retriever.py @@ -8,7 +8,7 @@ class FAISSRetriever(VectorIndexRetriever): """FAISS retriever.""" def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: - """Support add nodes""" + """Support add nodes.""" self._index.insert_nodes(nodes, **kwargs) def persist(self, persist_dir: str, **kwargs) -> None: diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index cae1c2979f..183f6e0c76 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,11 +1,12 @@ """RAG schemas.""" from pathlib import Path -from typing import Any, Union +from typing import Any, Literal, Union from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode +from llama_index.core.vector_stores.types import VectorStoreQueryMode from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from metagpt.rag.interface import RAGObject @@ -46,6 +47,35 @@ class ChromaRetrieverConfig(IndexRetrieverConfig): collection_name: str = Field(default="metagpt", description="The name of the collection.") +class ElasticsearchStoreConfig(BaseModel): + index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.") + es_url: str = Field(default=None, description="Elasticsearch URL.") + es_cloud_id: str = Field(default=None, description="Elasticsearch cloud ID.") + es_api_key: str = Field(default=None, description="Elasticsearch API key.") + es_user: str = Field(default=None, description="Elasticsearch username.") + es_password: str = Field(default=None, description="Elasticsearch password.") + batch_size: int = Field(default=200, description="Batch size for bulk indexing.") + distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.") + + +class ElasticsearchRetrieverConfig(IndexRetrieverConfig): + """Config for Elasticsearch-based retrievers. Support both vector and text.""" + + store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") + vector_store_query_mode: VectorStoreQueryMode = Field( + default=VectorStoreQueryMode.DEFAULT, description="default is vector query." + ) + + +class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig): + """Config for Elasticsearch-based retrievers. Support text only.""" + + _no_embedding: bool = PrivateAttr(default=True) + vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field( + default=VectorStoreQueryMode.TEXT_SEARCH, description="text query only." + ) + + class BaseRankerConfig(BaseModel): """Common config for rankers. @@ -53,7 +83,6 @@ class BaseRankerConfig(BaseModel): """ model_config = ConfigDict(arbitrary_types_allowed=True) - top_n: int = Field(default=5, description="The number of top results to return.") @@ -66,12 +95,24 @@ class LLMRankerConfig(BaseRankerConfig): ) +class ColbertRerankConfig(BaseRankerConfig): + model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.") + device: str = Field(default="cpu", description="Device to use for sentence transformer.") + keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.") + + +class ObjectRankerConfig(BaseRankerConfig): + field_name: str = Field(..., description="field name of the object, field's value must can be compared.") + order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") + + class BaseIndexConfig(BaseModel): """Common config for index. If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index. """ + model_config = ConfigDict(arbitrary_types_allowed=True) persist_path: Union[str, Path] = Field(description="The directory of saved data.") @@ -97,6 +138,19 @@ class BM25IndexConfig(BaseIndexConfig): _no_embedding: bool = PrivateAttr(default=True) +class ElasticsearchIndexConfig(VectorIndexConfig): + """Config for es-based index.""" + + store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") + persist_path: Union[str, Path] = "" + + +class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig): + """Config for es-based index. no embedding.""" + + _no_embedding: bool = PrivateAttr(default=True) + + class ObjectNodeMetadata(BaseModel): """Metadata of ObjectNode.""" diff --git a/setup.py b/setup.py index f834b4c449..c728872ef2 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,8 @@ def run(self): "llama-index-readers-file==0.1.4", "llama-index-retrievers-bm25==0.1.3", "llama-index-vector-stores-faiss==0.1.1", + "llama-index-vector-stores-elasticsearch==0.1.6", + "llama-index-postprocessor-colbert-rerank==0.1.1", "chromadb==0.4.23", ], } diff --git a/tests/metagpt/rag/rankers/test_object_ranker.py b/tests/metagpt/rag/rankers/test_object_ranker.py new file mode 100644 index 0000000000..7ea6b7488b --- /dev/null +++ b/tests/metagpt/rag/rankers/test_object_ranker.py @@ -0,0 +1,60 @@ +import json + +import pytest +from llama_index.core.schema import NodeWithScore, QueryBundle +from pydantic import BaseModel + +from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor +from metagpt.rag.schema import ObjectNode + + +class Record(BaseModel): + score: int + + +class TestObjectSortPostprocessor: + @pytest.fixture + def nodes_with_scores(self): + nodes = [ + NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10), + NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20), + NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=5).model_dump_json()}), score=5), + ] + return nodes + + @pytest.fixture + def query_bundle(self, mocker): + return mocker.MagicMock(spec=QueryBundle) + + def test_sort_descending(self, nodes_with_scores, query_bundle): + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") + sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle) + assert [node.score for node in sorted_nodes] == [20, 10, 5] + + def test_sort_ascending(self, nodes_with_scores, query_bundle): + postprocessor = ObjectSortPostprocessor(field_name="score", order="asc") + sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle) + assert [node.score for node in sorted_nodes] == [5, 10, 20] + + def test_top_n_limit(self, nodes_with_scores, query_bundle): + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2) + sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle) + assert len(sorted_nodes) == 2 + assert [node.score for node in sorted_nodes] == [20, 10] + + def test_invalid_json_metadata(self, query_bundle): + nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)] + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") + with pytest.raises(ValueError): + postprocessor._postprocess_nodes(nodes, query_bundle) + + def test_missing_query_bundle(self, nodes_with_scores): + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") + with pytest.raises(ValueError): + postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None) + + def test_field_not_found_in_object(self): + nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)] + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") + with pytest.raises(ValueError): + postprocessor._postprocess_nodes(nodes)