Berke/llama index tests (#5512)
Add LlamaIndex tests.

GitOrigin-RevId: bd952d14946508ba31922ae37762d1224527783e
berkecanrizai authored and Manul from Pathway committed Feb 2, 2024
1 parent eaf8d1e commit 1ec6b44
Showing 4 changed files with 188 additions and 1 deletion.
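The new test module spins up a PathwayVectorServer in a child process and queries it through LlamaIndex's PathwayRetriever and PathwayReader clients. As a minimal sketch of the client-side round trip the tests exercise (assuming a server is already listening on the host and port constants defined in the test file below):

from llama_index.retrievers import PathwayRetriever

# Host and default port match PATHWAY_HOST / PATHWAY_PORT in the test module.
retriever = PathwayRetriever(host="127.0.0.1", port=7769)
nodes = retriever.retrieve(str_or_query_bundle="Hello world.")
print(nodes[0].text, nodes[0].score)
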
1 change: 1 addition & 0 deletions .github/workflows/package_test.yml
@@ -115,6 +115,7 @@ jobs:
source "${ENV_NAME}/bin/activate"
WHEEL=(target/wheels/pathway-*.whl)
pip install --prefer-binary "${WHEEL}[tests]"
+ pip install llama_index
# --confcutdir anything below to avoid picking REPO_TOP_DIR/conftest.py
if [[ "$RUNNER_NAME" == *mac* ]]; then
export PYTEST_XDIST_AUTO_NUM_WORKERS=4
5 changes: 4 additions & 1 deletion .github/workflows/release.yml
@@ -174,6 +174,7 @@ jobs:
source "${ENV_NAME}/bin/activate"
WHEEL=(./wheels/pathway-*.whl)
pip install --prefer-binary "${WHEEL}[tests]"
+ pip install llama_index
# --confcutdir anything below to avoid picking REPO_TOP_DIR/conftest.py
export PYTEST_ADDOPTS="--dist worksteal -n auto"
python -m pytest -v --confcutdir "${ENV_NAME}" --doctest-modules --pyargs pathway
@@ -217,8 +218,10 @@ jobs:
python"${{ matrix.python-version }}" -m venv "${ENV_NAME}"
source "${ENV_NAME}/bin/activate"
WHEEL=(./wheels/pathway-*.whl)
- PATHWAY_MONITORING_HTTP_PORT=20099
+ export PATHWAY_MONITORING_HTTP_PORT=20099
+ export LLAMA_READER_PORT=8799
pip install --prefer-binary "${WHEEL}[tests]"
+ pip install llama_index
# --confcutdir anything below to avoid picking REPO_TOP_DIR/conftest.py
if [[ "$RUNNER_NAME" == *mac* ]]; then
export PYTEST_XDIST_AUTO_NUM_WORKERS=2
1 change: 1 addition & 0 deletions python/pathway/xpacks/llm/tests/example_text.md
@@ -0,0 +1 @@
Hello world.,This is a test.
182 changes: 182 additions & 0 deletions python/pathway/xpacks/llm/tests/test_llamaindex_vs.py
@@ -0,0 +1,182 @@
import os
import time
from multiprocessing import Process
from typing import List

import requests
from llama_index.embeddings.base import BaseEmbedding
from llama_index.node_parser import TextSplitter
from llama_index.retrievers import PathwayVectorServer

import pathway as pw
from pathway.tests.utils import xfail_on_multiple_threads

PATHWAY_HOST = "127.0.0.1"

ENV_PORT = os.environ.get("LLAMA_READER_PORT")
PATHWAY_PORT = 7769

if ENV_PORT:
    PATHWAY_PORT = int(ENV_PORT)

EXAMPLE_TEXT_FILE = "example_text.md"


def get_data_sources():
    test_dir = os.path.dirname(os.path.abspath(__file__))
    example_text_path = os.path.join(test_dir, EXAMPLE_TEXT_FILE)

    data_sources = []
    data_sources.append(
        pw.io.fs.read(
            example_text_path,
            format="binary",
            mode="streaming",
            with_metadata=True,
        )
    )
    return data_sources
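
# pw.io.fs.read in "streaming" mode keeps watching the file for changes,
# so the tests below poll with retries rather than asserting immediately
# after server startup.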


def mock_get_text_embedding(text: str) -> List[float]:
    """Deterministic mock embedding: one fixed vector per known sentence."""
    print("Embedder embedding:", text)
    if text == "Hello world.":
        return [1, 0, 0, 0, 0]
    elif text == "This is a test.":
        return [0, 1, 0, 0, 0]
    elif text == "This is another test.":
        return [0, 0, 1, 0, 0]
    elif text == "This is a test v2.":
        return [0, 0, 0, 1, 0]
    elif text == "This is a test v3.":
        return [0, 0, 0, 0, 1]
    elif text == "This is bar test.":
        return [0, 0, 1, 0, 0]
    elif text == "Hello world backup.":
        return [0, 0, 0, 0, 1]
    else:
        return [0, 0, 0, 0, 0]
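
# The vectors above are one-hot on purpose: a query identical to a stored
# chunk embeds to the same unit vector, so the retriever's similarity score
# is exactly 1.0, which test_llama_retriever asserts on below.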


class NewlineTextSplitter(TextSplitter):
    # Despite its name, this splitter splits on commas, matching the ","
    # separator in example_text.md.
    def split_text(self, text: str) -> List[str]:
        return text.split(",")


class FakeEmbedding(BaseEmbedding):
    def _get_text_embedding(self, text: str) -> List[float]:
        return mock_get_text_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        return mock_get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return mock_get_text_embedding(query)


def pathway_server(port):
    data_sources = get_data_sources()

    embed_model = FakeEmbedding()

    custom_transformations = [
        NewlineTextSplitter(),
        embed_model,
    ]

    processing_pipeline = PathwayVectorServer(
        *data_sources,
        transformations=custom_transformations,  # type: ignore
    )

    # run_server(threaded=True) returns the server thread; join() keeps this
    # child process alive until the test terminates it.
    thread = processing_pipeline.run_server(
        host=PATHWAY_HOST,
        port=port,
        threaded=True,
        with_cache=False,
    )
    thread.join()


@xfail_on_multiple_threads
def test_llama_retriever():
    port = PATHWAY_PORT
    p = Process(target=pathway_server, args=[port])
    p.start()

    time.sleep(15)

    from llama_index.retrievers import PathwayRetriever

    retriever = PathwayRetriever(host=PATHWAY_HOST, port=port)

    MAX_ATTEMPTS = 8
    attempts = 0
    results = []
    while attempts < MAX_ATTEMPTS:
        try:
            results = retriever.retrieve(str_or_query_bundle="Hello world.")
        except requests.exceptions.RequestException:
            pass
        else:
            break
        time.sleep(1)
        attempts += 1

    assert len(results) == 1
    assert results[0].text == "Hello world."
    assert results[0].score == 1.0
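    # thread.join() in pathway_server blocks indefinitely, so the server
    # process is shut down explicitly below rather than joined.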

    p.terminate()


@xfail_on_multiple_threads
def test_llama_reader():
    port = PATHWAY_PORT + 1
    p = Process(target=pathway_server, args=[port])
    p.start()

    time.sleep(20)

    from llama_index.readers import PathwayReader

    pr = PathwayReader(host=PATHWAY_HOST, port=port)

    MAX_ATTEMPTS = 8
    attempts = 0
    results = []
    while attempts < MAX_ATTEMPTS:
        try:
            results = pr.load_data("Hello world.", k=1)
        except requests.exceptions.RequestException:
            pass
        else:
            break
        time.sleep(1)
        attempts += 1

    assert len(results) == 1

    first_result = results[0]
    assert first_result.text == "Hello world."
    assert EXAMPLE_TEXT_FILE in first_result.metadata["path"]

    attempts = 0
    results = []
    while attempts < MAX_ATTEMPTS:
        try:
            results = pr.load_data("This is a test.", k=1)
        except requests.exceptions.RequestException:
            pass
        else:
            break
        time.sleep(1)
        attempts += 1

    first_result = results[0]
    assert first_result.text == "This is a test."
    assert EXAMPLE_TEXT_FILE in first_result.metadata["path"]

    p.terminate()
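
To mirror the CI setup locally, the workflows above install llama_index next to the pathway wheel and export LLAMA_READER_PORT before running pytest. A sketch of an equivalent local run (the port value copies release.yml and is otherwise arbitrary):

# Hypothetical local invocation; assumes pathway, its test extras, and
# llama_index are installed in the active environment.
import os
import subprocess

os.environ["LLAMA_READER_PORT"] = "8799"  # same value release.yml exports
subprocess.run(
    ["python", "-m", "pytest", "-v",
     "python/pathway/xpacks/llm/tests/test_llamaindex_vs.py"],
    check=True,
)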
