-
Notifications
You must be signed in to change notification settings - Fork 284
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Llamaindex tests. GitOrigin-RevId: bd952d14946508ba31922ae37762d1224527783e
- Loading branch information
1 parent
eaf8d1e
commit 1ec6b44
Showing
4 changed files
with
188 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Hello world.,This is a test. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import os | ||
import time | ||
from multiprocessing import Process | ||
from typing import List | ||
|
||
import requests | ||
from llama_index.embeddings.base import BaseEmbedding | ||
from llama_index.node_parser import TextSplitter | ||
from llama_index.retrievers import PathwayVectorServer | ||
|
||
import pathway as pw | ||
from pathway.tests.utils import xfail_on_multiple_threads | ||
|
||
# Host used both when starting the server and when creating the clients.
PATHWAY_HOST = "127.0.0.1"

# Optional port override read from the environment.
ENV_PORT = os.environ.get("LLAMA_READER_PORT")

# Use the override when it is set and non-empty, otherwise port 7769.
PATHWAY_PORT = int(ENV_PORT) if ENV_PORT else 7769

# File name of the text document, resolved relative to this module's directory.
EXAMPLE_TEXT_FILE = "example_text.md"
|
||
|
||
def get_data_sources():
    """Build the list of Pathway inputs used by the test server.

    Returns:
        A one-element list holding a ``pw.io.fs.read`` source for
        ``EXAMPLE_TEXT_FILE``, located in the same directory as this module
        and read in binary format, streaming mode, with metadata enabled.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    source_path = os.path.join(module_dir, EXAMPLE_TEXT_FILE)

    return [
        pw.io.fs.read(
            source_path,
            format="binary",
            mode="streaming",
            with_metadata=True,
        )
    ]
|
||
|
||
def mock_get_text_embedding(text: str) -> List[float]:
    """Return a fixed 5-dimensional embedding for a known test sentence.

    Every recognised sentence maps to a one-hot vector (some sentences share
    the same vector); any other text maps to the all-zero vector. The text is
    printed before the lookup.

    Args:
        text: Sentence to embed. Matching is exact.

    Returns:
        A newly created list of five numbers on every call.
    """
    print("Embedder embedding:", text)

    # Position of the single 1 in the vector for each recognised sentence.
    hot_index = {
        "Hello world.": 0,
        "This is a test.": 1,
        "This is another test.": 2,
        "This is a test v2.": 3,
        "This is a test v3.": 4,
        "This is bar test.": 2,
        "Hello world backup.": 4,
    }.get(text)

    embedding = [0] * 5
    if hot_index is not None:
        embedding[hot_index] = 1
    return embedding
|
||
|
||
class NewlineTextSplitter(TextSplitter):
    """Test splitter that cuts text at every comma.

    NOTE: despite the class name, the separator is ``","`` and not a newline.
    """

    def split_text(self, text: str) -> List[str]:
        """Return the comma-separated pieces of ``text``."""
        separator = ","
        return text.split(separator)
|
||
|
||
class FakeEmbedding(BaseEmbedding):
    """Embedding model stub backed by ``mock_get_text_embedding``.

    Text and query embeddings come from the same lookup, so a query that
    equals a stored sentence gets exactly that sentence's vector.
    """

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a document chunk with the mock lookup."""
        return mock_get_text_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a query with the mock lookup."""
        return mock_get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async variant of ``_get_query_embedding``; returns the same vector."""
        return self._get_query_embedding(query)
|
||
|
||
def pathway_server(port):
    """Start a Pathway vector server on ``port`` and block on its thread.

    The server indexes the inputs from ``get_data_sources()`` and processes
    them with a comma splitter followed by the fake embedding model. It
    listens on ``PATHWAY_HOST`` with caching disabled. The tests in this
    module run this function as the target of a child ``Process``.

    Args:
        port: TCP port passed to ``run_server``.
    """
    sources = get_data_sources()
    embedder = FakeEmbedding()

    vector_server = PathwayVectorServer(
        *sources,
        transformations=[NewlineTextSplitter(), embedder],  # type: ignore
    )

    # run_server(threaded=True) hands back a thread; joining it keeps this
    # function (and the process running it) alive while the server runs.
    server_thread = vector_server.run_server(
        host=PATHWAY_HOST,
        port=port,
        threaded=True,
        with_cache=False,
    )
    server_thread.join()
|
||
|
||
@xfail_on_multiple_threads
def test_llama_retriever():
    """Check that ``PathwayRetriever`` returns the exact-match chunk.

    Starts ``pathway_server`` in a child process on ``PATHWAY_PORT``, queries
    it for "Hello world." and expects a single result with that text and a
    score of 1.0. The child process is always terminated, including when an
    assertion fails, so a failed run does not leave the port occupied.
    """
    port = PATHWAY_PORT
    p = Process(target=pathway_server, args=[port])
    p.start()

    try:
        # Fixed wait before the first query; presumably gives the server
        # time to start up and index the example file.
        time.sleep(15)

        from llama_index.retrievers import PathwayRetriever

        retriever = PathwayRetriever(host=PATHWAY_HOST, port=port)

        MAX_ATTEMPTS = 8
        results = []
        for _ in range(MAX_ATTEMPTS):
            try:
                results = retriever.retrieve(str_or_query_bundle="Hello world.")
            except requests.exceptions.RequestException:
                # Request failed; pause briefly and retry.
                time.sleep(1)
            else:
                break

        assert len(results) == 1
        assert results[0].text == "Hello world."
        assert results[0].score == 1.0
    finally:
        p.terminate()
        # Reap the child; the timeout keeps the test from hanging here.
        p.join(timeout=10)
|
||
|
||
@xfail_on_multiple_threads
def test_llama_reader():
    """Check that ``PathwayReader`` loads the best-matching chunk per query.

    Starts ``pathway_server`` in a child process on ``PATHWAY_PORT + 1`` and
    runs two ``k=1`` queries. Each must return exactly one document whose
    text equals the query and whose ``path`` metadata contains
    ``EXAMPLE_TEXT_FILE``. The child process is always terminated, including
    when an assertion fails.
    """
    port = PATHWAY_PORT + 1
    p = Process(target=pathway_server, args=[port])
    p.start()

    try:
        # Fixed wait before the first query; presumably gives the server
        # time to start up and index the example file.
        time.sleep(20)

        from llama_index.readers import PathwayReader

        pr = PathwayReader(host=PATHWAY_HOST, port=port)

        MAX_ATTEMPTS = 8

        def load_with_retry(query):
            """Run ``pr.load_data(query, k=1)``, retrying on request errors.

            Returns an empty list when every attempt fails.
            """
            for _ in range(MAX_ATTEMPTS):
                try:
                    return pr.load_data(query, k=1)
                except requests.exceptions.RequestException:
                    time.sleep(1)
            return []

        for query in ("Hello world.", "This is a test."):
            results = load_with_retry(query)
            # Checked for both queries so an exhausted retry budget shows up
            # as a clear assertion failure rather than an IndexError.
            assert len(results) == 1

            first_result = results[0]
            assert first_result.text == query
            assert EXAMPLE_TEXT_FILE in first_result.metadata["path"]
    finally:
        p.terminate()
        # Reap the child; the timeout keeps the test from hanging here.
        p.join(timeout=10)