Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add graphrag modes #574

Merged
merged 5 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@

if USE_NANO_GRAPHRAG:
GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.NanoGraphRAGIndex")
elif USE_LIGHTRAG:
if USE_LIGHTRAG:
GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.LightRAGIndex")

KH_INDEX_TYPES = [
Expand Down
11 changes: 11 additions & 0 deletions libs/ktem/ktem/assets/css/main.css
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,24 @@ mark {
right: 15px;
}

/* Prevent the HTML info panel from overflowing horizontally: content wider
   than the panel scrolls on the x-axis instead. `!important` lets this
   declaration win over other normal overflow-x declarations for the panel. */
#html-info-panel {
overflow-x: auto !important;
}

/* Chat expand button: absolutely positioned 6px from the top of its
   containing block and 10px past its right edge (negative `right`), with
   z-index 1. */
#chat-expand-button {
position: absolute;
top: 6px;
right: -10px;
z-index: 1;
}

/* Save-setting button: fixed 150px x 30px box. The 100px min-width is marked
   `!important` so it takes priority over other normal min-width declarations
   that apply to the button. */
#save-setting-btn {
width: 150px;
height: 30px;
min-width: 100px !important;
}

#quick-setting-labels {
margin-top: 5px;
margin-bottom: -10px;
Expand Down
5 changes: 5 additions & 0 deletions libs/ktem/ktem/assets/js/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ function run() {
let chat_column = document.getElementById("main-chat-bot");
let conv_column = document.getElementById("conv-settings-panel");

// move setting close button
let setting_tab_nav_bar = document.querySelector("#settings-tab .tab-nav");
let setting_close_button = document.getElementById("save-setting-btn");
setting_tab_nav_bar.appendChild(setting_close_button);

let default_conv_column_min_width = "min(300px, 100%)";
conv_column.style.minWidth = default_conv_column_min_width

Expand Down
20 changes: 19 additions & 1 deletion libs/ktem/ktem/index/file/graph/light_graph_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from ..base import BaseFileIndexRetriever
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline

Expand All @@ -12,14 +12,32 @@ def _setup_indexing_cls(self):
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]

def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
    """Build the indexing pipeline and pass it this index's own options.

    The pipeline is created by the parent implementation. Every entry of
    ``settings`` whose key starts with ``index.options.<index id>.`` is then
    copied, with that prefix removed, into a dict that is assigned to the
    pipeline's ``prompts`` attribute.

    Args:
        settings: flat mapping of setting keys to values.
        user_id: forwarded unchanged to the parent implementation.

    Returns:
        The pipeline returned by the parent, with ``prompts`` set.
    """
    indexing_pipeline = super().get_indexing_pipeline(settings, user_id)

    # keep only the options namespaced under this index
    option_prefix = f"index.options.{self.id}."
    prefix_length = len(option_prefix)
    index_options = {}
    for full_key, option_value in settings.items():
        if not full_key.startswith(option_prefix):
            continue
        index_options[full_key[prefix_length:]] = option_value

    indexing_pipeline.prompts = index_options
    return indexing_pipeline

def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
# retrieval settings
prefix = f"index.options.{self.id}."
search_type = settings.get(prefix + "search_type", "local")

retrievers = [
LightRAGRetrieverPipeline(
file_ids=file_ids,
Index=self._resources["Index"],
search_type=search_type,
)
]

Expand Down
100 changes: 82 additions & 18 deletions libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def llm_func(
if if_cache_return is not None:
return if_cache_return["return"]

output = model(input_messages).text
output = (await model.ainvoke(input_messages)).text

print("-" * 50)
print(output, "\n", "-" * 50)
Expand Down Expand Up @@ -220,7 +220,37 @@ def build_graphrag(working_dir, llm_func, embedding_func):
class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
"""GraphRAG specific indexing pipeline"""

prompts: dict[str, str] = {}

@classmethod
def get_user_settings(cls) -> dict:
    """Expose the LightRAG prompt templates as user-editable text settings.

    Returns:
        One ``"text"`` setting per entry of ``lightrag.prompt.PROMPTS`` whose
        name (compared in lower case) contains none of "default",
        "response" or "process". If ``lightrag`` cannot be imported, the
        import error is printed and an empty dict is returned.
    """
    try:
        from lightrag.prompt import PROMPTS

        excluded_fragments = ("default", "response", "process")
        user_settings = {}
        for prompt_name, content in PROMPTS.items():
            lowered_name = prompt_name.lower()
            if any(fragment in lowered_name for fragment in excluded_fragments):
                continue
            user_settings[prompt_name] = {
                "name": f"Prompt for '{prompt_name}'",
                "value": content,
                "component": "text",
            }
        return user_settings
    except ImportError as e:
        print(e)
        return {}

def call_graphrag_index(self, graph_id: str, docs: list[Document]):
from lightrag.prompt import PROMPTS

# modify the prompt if it is set in the settings
for prompt_name, content in self.prompts.items():
if prompt_name in PROMPTS:
PROMPTS[prompt_name] = content

_, input_path = prepare_graph_index_path(graph_id)
input_path.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -302,6 +332,19 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever):

Index = Param(help="The SQLAlchemy Index table")
file_ids: list[str] = []
search_type: str = "local"

@classmethod
def get_user_settings(cls) -> dict:
    """Describe the retrieval options shown to the user for this retriever.

    Returns:
        A mapping with a single ``"search_type"`` dropdown setting that
        defaults to ``"local"`` and offers ``"local"``, ``"global"`` and
        ``"hybrid"``.
    """
    return {
        "search_type": {
            "name": "Search type",
            "value": "local",
            "choices": ["local", "global", "hybrid"],
            "component": "dropdown",
            # the help text must name every choice offered above
            "info": (
                "Whether to use local, global, or hybrid search in the graph."
            ),
        }
    }

def _build_graph_search(self):
file_id = self.file_ids[0]
Expand All @@ -326,7 +369,8 @@ def _build_graph_search(self):
llm_func=llm_func,
embedding_func=embedding_func,
)
query_params = QueryParam(mode="local", only_need_context=True)
print("search_type", self.search_type)
query_params = QueryParam(mode=self.search_type, only_need_context=True)

return graphrag_func, query_params

Expand Down Expand Up @@ -381,20 +425,40 @@ def run(
return []

graphrag_func, query_params = self._build_graph_search()
entities, relationships, sources = asyncio.run(
lightrag_build_local_query_context(graphrag_func, text, query_params)
)

documents = self.format_context_records(entities, relationships, sources)
plot = self.plot_graph(relationships)

return documents + [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
# only local mode supports graph visualization
if query_params.mode == "local":
entities, relationships, sources = asyncio.run(
lightrag_build_local_query_context(graphrag_func, text, query_params)
)
documents = self.format_context_records(entities, relationships, sources)
plot = self.plot_graph(relationships)
documents += [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
else:
context = graphrag_func.query(text, query_params)

# account for missing ``` for closing code block
context += "\n```"

documents = [
RetrievedDocument(
text=context,
metadata={
"file_name": "GraphRAG {} Search".format(
query_params.mode.capitalize()
),
"type": "table",
},
)
]

return documents
20 changes: 19 additions & 1 deletion libs/ktem/ktem/index/file/graph/nano_graph_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from ..base import BaseFileIndexRetriever
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline

Expand All @@ -12,14 +12,32 @@ def _setup_indexing_cls(self):
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]

def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
    """Build the indexing pipeline and pass it this index's own options.

    The pipeline is created by the parent implementation. Every entry of
    ``settings`` whose key starts with ``index.options.<index id>.`` is then
    copied, with that prefix removed, into a dict that is assigned to the
    pipeline's ``prompts`` attribute.

    Args:
        settings: flat mapping of setting keys to values.
        user_id: forwarded unchanged to the parent implementation.

    Returns:
        The pipeline returned by the parent, with ``prompts`` set.
    """
    indexing_pipeline = super().get_indexing_pipeline(settings, user_id)

    # keep only the options namespaced under this index
    option_prefix = f"index.options.{self.id}."
    prefix_length = len(option_prefix)
    index_options = {}
    for full_key, option_value in settings.items():
        if not full_key.startswith(option_prefix):
            continue
        index_options[full_key[prefix_length:]] = option_value

    indexing_pipeline.prompts = index_options
    return indexing_pipeline

def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
# retrieval settings
prefix = f"index.options.{self.id}."
search_type = settings.get(prefix + "search_type", "local")

retrievers = [
NanoGraphRAGRetrieverPipeline(
file_ids=file_ids,
Index=self._resources["Index"],
search_type=search_type,
)
]

Expand Down
Loading
Loading