Skip to content

Commit

Permalink
New workflow to generate embeddings in a single workflow (#1296)
Browse files Browse the repository at this point in the history
* New workflow to generate embeddings in a single workflow

* New workflow to generate embeddings in a single workflow

* version change

* clean tests without any embeddings references

* clean tests without any embeddings references

* remove code

* feedback implemented

* changes in logic

* feedback implemented

* store in table bug fixed

* smoke test for generate_text_embeddings workflow

* smoke test fix

* add generate_text_embeddings to the list of transient workflows

* smoke tests

* fix

* ruff formatting updates

* fix

* smoke test fixed

* smoke test fixed

* fix lancedb import

* smoke test fix

* ignore sorting

* smoke test fixed

* smoke test fixed

* check smoke test

* smoke test fixed

* change config for vector store

* format fix

* vector store changes

* revert debug profile back to empty filepath

* merge conflict solved

* merge conflict solved

* format fixed

* format fixed

* fix return dataframe

* snapshot fix

* format fix

* embeddings param implemented

* validation fixes

* fix map

* fix map

* fix properties

* config updates

* smoke test fixed

* settings change

* Update collection config and rework back-compat

* Replace . with - for embedding store

---------

Co-authored-by: Alonso Guevara <[email protected]>
Co-authored-by: Josh Bradley <[email protected]>
Co-authored-by: Nathan Evans <[email protected]>
  • Loading branch information
4 people authored Nov 1, 2024
1 parent 8302920 commit 17658c5
Show file tree
Hide file tree
Showing 51 changed files with 693 additions and 804 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20241018204541069382.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "embeddings moved to a different workflow"
}
6 changes: 3 additions & 3 deletions docs/config/json_yaml.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,16 @@ This is the base LLM configuration section. Other steps may override this config
- `async_mode` (see Async Mode top-level config)
- `batch_size` **int** - The maximum batch size to use.
- `batch_max_tokens` **int** - The maximum batch # of tokens.
- `target` **required|all** - Determines which set of embeddings to emit.
- `skip` **list[str]** - Which embeddings to skip.
- `target` **required|all|none** - Determines which set of embeddings to emit.
- `skip` **list[str]** - Which embeddings to skip. Only useful if target=all to customize the list.
- `vector_store` **dict** - The vector store to use. Configured for lancedb by default.
- `type` **str** - `lancedb` or `azure_ai_search`. Default=`lancedb`
- `db_uri` **str** (only for lancedb) - The database uri. Default=`storage.base_dir/lancedb`
- `url` **str** (only for AI Search) - AI Search endpoint
- `api_key` **str** (optional - only for AI Search) - The AI Search api key to use.
- `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used.
- `overwrite` **bool** (only used at index creation time) - Overwrite collection if it exists. Default=`True`
- `collection_name` **str** - The name of a vector collection. Default=`entity_description_embeddings`
- `container_name` **str** - The name of a vector container. This stores all indexes (tables) for a given dataset ingest. Default=`default`
- `strategy` **dict** - Fully override the text-embedding strategy.

## chunks
Expand Down
2 changes: 1 addition & 1 deletion docs/examples_notebooks/local_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"entity_description_embeddings\",\n",
" collection_name=\"entity.description\",\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"entity_description_embeddings = store_entity_semantic_embeddings(\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
"\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"entity_description_embeddings\",\n",
" collection_name=\"entity.description\",\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"entity_description_embeddings = store_entity_semantic_embeddings(\n",
Expand Down
27 changes: 20 additions & 7 deletions graphrag/api/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,7 @@ async def build_index(
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)

# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore
config = _patch_vector_config(config)

pipeline_config = create_pipeline_config(config)
pipeline_cache = (
Expand All @@ -90,3 +84,22 @@ async def build_index(
progress_reporter.success(output.workflow)
progress_reporter.info(str(output.result))
return outputs


def _patch_vector_config(config: GraphRagConfig):
    """Back-compat patch to ensure a default vector store configuration."""
    vector_store = config.embeddings.vector_store
    if not vector_store:
        # No vector store configured: fall back to a local lancedb store.
        vector_store = {
            "type": "lancedb",
            "db_uri": "output/lancedb",
            "container_name": "default",
            "overwrite": True,
        }
        config.embeddings.vector_store = vector_store
    # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
    # TODO: remove the type ignore annotations below once the new config engine has been refactored
    if vector_store["type"] == VectorStoreType.LanceDB:  # type: ignore
        # Resolve the lancedb uri relative to the configured root directory.
        resolved_uri = Path(config.root_dir).resolve() / vector_store["db_uri"]  # type: ignore
        vector_store["db_uri"] = str(resolved_uri)  # type: ignore
    return config
136 changes: 55 additions & 81 deletions graphrag/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,56 +182,22 @@ async def local_search(
------
TODO: Document any exceptions to expect.
"""
#################################### BEGIN PATCH ####################################
# TODO: remove the following patch that checks for a vector_store prior to v1 release
# TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
# Only applicable in situations involving a local vector_store (lancedb). The general idea:
# if vector_store not in config:
# 1. assume user is running local if vector_store is not in config
# 2. insert default vector_store in config
# 3 .create lancedb vector_store instance
# 4. upload vector embeddings from the input dataframes to the vector_store
backwards_compatible = False
if not config.embeddings.vector_store:
backwards_compatible = True
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.vector_stores.lancedb import LanceDBVectorStore

config.embeddings.vector_store = {
"type": "lancedb",
"db_uri": f"{Path(config.storage.base_dir)}/lancedb",
"collection_name": "entity_description_embeddings",
"overwrite": True,
}
_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = LanceDBVectorStore(
db_uri=config.embeddings.vector_store["db_uri"],
collection_name=config.embeddings.vector_store["collection_name"],
overwrite=config.embeddings.vector_store["overwrite"],
)
description_embedding_store.connect(
db_uri=config.embeddings.vector_store["db_uri"]
)
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=_entities, vectorstore=description_embedding_store
)
#################################### END PATCH ####################################
config = _patch_vector_store(config, nodes, entities, community_level)

# TODO: update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible:
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore

reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
if not backwards_compatible: # can remove this check and always set the description_embedding_store before v1 release
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)

description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)

_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
Expand Down Expand Up @@ -289,56 +255,22 @@ async def local_search_streaming(
------
TODO: Document any exceptions to expect.
"""
#################################### BEGIN PATCH ####################################
# TODO: remove the following patch that checks for a vector_store prior to v1 release
# TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
# Only applicable in situations involving a local vector_store (lancedb). The general idea:
# if vector_store not in config:
# 1. assume user is running local if vector_store is not in config
# 2. insert default vector_store in config
# 3 .create lancedb vector_store instance
# 4. upload vector embeddings from the input dataframes to the vector_store
backwards_compatible = False
if not config.embeddings.vector_store:
backwards_compatible = True
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.vector_stores.lancedb import LanceDBVectorStore

config.embeddings.vector_store = {
"type": "lancedb",
"db_uri": f"{Path(config.storage.base_dir)}/lancedb",
"collection_name": "entity_description_embeddings",
"overwrite": True,
}
_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = LanceDBVectorStore(
db_uri=config.embeddings.vector_store["db_uri"],
collection_name=config.embeddings.vector_store["collection_name"],
overwrite=config.embeddings.vector_store["overwrite"],
)
description_embedding_store.connect(
db_uri=config.embeddings.vector_store["db_uri"]
)
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=_entities, vectorstore=description_embedding_store
)
#################################### END PATCH ####################################
config = _patch_vector_store(config, nodes, entities, community_level)

# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible:
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore

reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
if not backwards_compatible: # can remove this check and always set the description_embedding_store before v1 release
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)

description_embedding_store = _get_embedding_description_store(
    config_args=vector_store_args,  # type: ignore
)

_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
Expand Down Expand Up @@ -368,13 +300,55 @@ async def local_search_streaming(
yield stream_chunk


def _patch_vector_store(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_level: int,
) -> GraphRagConfig:
    # TODO: remove this backwards-compatibility patch prior to the v1 release
    # It injects default vector_store settings into the config when absent.
    # Only applicable in situations involving a local vector_store (lancedb):
    #   1. assume the user is running locally if vector_store is not in config
    #   2. insert the default lancedb vector_store settings into the config
    #   3. create a lancedb vector store instance
    #   4. upload vector embeddings from the input dataframes to the vector store
    if config.embeddings.vector_store:
        return config

    from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
    from graphrag.vector_stores.lancedb import LanceDBVectorStore

    defaults = {
        "type": "lancedb",
        "db_uri": f"{Path(config.storage.base_dir)}/lancedb",
        "container_name": "default",
        "overwrite": True,
    }
    config.embeddings.vector_store = defaults
    description_embedding_store = LanceDBVectorStore(
        db_uri=defaults["db_uri"],
        collection_name="default-entity-description",
        overwrite=defaults["overwrite"],
    )
    description_embedding_store.connect(db_uri=defaults["db_uri"])
    # dump embeddings from the entities list to the description_embedding_store
    store_entity_semantic_embeddings(
        entities=read_indexer_entities(nodes, entities, community_level),
        vectorstore=description_embedding_store,
    )
    return config


def _get_embedding_description_store(
    config_args: dict,
):
    """Get the embedding description store.

    Builds a vector store from the supplied configuration, deriving the
    collection name from the configured container name so each dataset
    ingest keeps its indexes namespaced (``<container_name>-entity-description``).
    """
    # NOTE: the scraped diff fused the removed call line with its replacement,
    # producing an invalid call; this is the intended post-change form.
    vector_store_type = config_args["type"]
    collection_name = f"{config_args['container_name']}-entity-description"
    description_embedding_store = VectorStoreFactory.get_vector_store(
        vector_store_type=vector_store_type,
        kwargs={**config_args, "collection_name": collection_name},
    )
    description_embedding_store.connect(**config_args)
    return description_embedding_store
Expand Down
5 changes: 4 additions & 1 deletion graphrag/cli/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def run_local_search(
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)

# TODO remove optional create_final_entities_description_embeddings.parquet to delete backwards compatibility
dataframe_dict = _resolve_parquet_files(
root_dir=root_dir,
config=config,
Expand All @@ -125,7 +126,9 @@ def run_local_search(
"create_final_relationships.parquet",
"create_final_entities.parquet",
],
optional_list=["create_final_covariates.parquet"],
optional_list=[
"create_final_covariates.parquet",
],
)
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
final_community_reports: pd.DataFrame = dataframe_dict[
Expand Down
1 change: 1 addition & 0 deletions graphrag/config/create_graphrag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def hydrate_parallelization_params(
raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES,
top_level_nodes=reader.bool("top_level_nodes")
or defs.SNAPSHOTS_TOP_LEVEL_NODES,
embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
)
with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")):
umap_model = UmapConfig(
Expand Down
3 changes: 2 additions & 1 deletion graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
SNAPSHOTS_GRAPHML = False
SNAPSHOTS_RAW_ENTITIES = False
SNAPSHOTS_TOP_LEVEL_NODES = False
SNAPSHOTS_EMBEDDINGS = False
STORAGE_BASE_DIR = "output"
STORAGE_TYPE = StorageType.file
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
Expand All @@ -91,7 +92,7 @@
VECTOR_STORE = f"""
type: {VectorStoreType.LanceDB.value}
db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
collection_name: entity_description_embeddings
collection_name: default
overwrite: true\
"""

Expand Down
1 change: 1 addition & 0 deletions graphrag/config/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class TextEmbeddingTarget(str, Enum):

all = "all"
required = "required"
none = "none"

def __repr__(self):
"""Get a string representation."""
Expand Down
4 changes: 4 additions & 0 deletions graphrag/config/models/snapshots_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ class SnapshotsConfig(BaseModel):
description="A flag indicating whether to take snapshots of top-level nodes.",
default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
)
embeddings: bool = Field(
description="A flag indicating whether to take snapshots of embeddings.",
default=defs.SNAPSHOTS_EMBEDDINGS,
)
22 changes: 22 additions & 0 deletions graphrag/index/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@
PipelineMemoryCacheConfig,
PipelineNoneCacheConfig,
)
from .embeddings import (
all_embeddings,
community_full_content_embedding,
community_summary_embedding,
community_title_embedding,
document_raw_content_embedding,
entity_description_embedding,
entity_name_embedding,
relationship_description_embedding,
required_embeddings,
text_unit_text_embedding,
)
from .input import (
PipelineCSVInputConfig,
PipelineInputConfig,
Expand Down Expand Up @@ -66,4 +78,14 @@
"PipelineWorkflowConfig",
"PipelineWorkflowReference",
"PipelineWorkflowStep",
"all_embeddings",
"community_full_content_embedding",
"community_summary_embedding",
"community_title_embedding",
"document_raw_content_embedding",
"entity_description_embedding",
"entity_name_embedding",
"relationship_description_embedding",
"required_embeddings",
"text_unit_text_embedding",
]
25 changes: 25 additions & 0 deletions graphrag/index/config/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing embeddings values."""

# Canonical embedding names, formatted as "<object>.<field>".
entity_name_embedding = "entity.name"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
document_raw_content_embedding = "document.raw_content"
community_title_embedding = "community.title"
community_summary_embedding = "community.summary"
community_full_content_embedding = "community.full_content"
text_unit_text_embedding = "text_unit.text"

# Every embedding the pipeline can emit (the "all" target).
all_embeddings: set[str] = {
    community_full_content_embedding,
    community_summary_embedding,
    community_title_embedding,
    document_raw_content_embedding,
    entity_description_embedding,
    entity_name_embedding,
    relationship_description_embedding,
    text_unit_text_embedding,
}
# The subset emitted for the "required" target.
required_embeddings: set[str] = {entity_description_embedding}
Loading

0 comments on commit 17658c5

Please sign in to comment.