refact: Refactor Document to be natively multimodal (#17204)
* rename resource fields

* refactor Document

* fix typing, bring back text_template for backward compat

* fix bug in keyval docstore

* make TextNode forward-compatible

* redo deprecations

* fix model identifier

* update mocks

* update mocks

* fix fixture check
masci authored Dec 11, 2024
1 parent 601d09d commit b004ea0
Showing 9 changed files with 151 additions and 85 deletions.
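
Taken together, the changes below let a Node or Document carry text, image, audio and video through dedicated MediaResource fields instead of a single text string. A minimal sketch of the resulting API, assuming llama-index-core at or after this commit (names taken from the diff, exact behaviour may vary):

    from llama_index.core.schema import Document, MediaResource

    doc = Document(
        text_resource=MediaResource(text="A caption for the attached picture."),
        image_resource=MediaResource(data=b"<raw image bytes>", mimetype="image/png"),
        metadata={"source": "example"},
    )
    print(doc.text)                      # backward-compatible view over text_resource
    print(doc.image_resource.mimetype)   # "image/png"
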
3 changes: 1 addition & 2 deletions llama-index-core/llama_index/core/node_parser/node_utils.py
@@ -1,6 +1,5 @@
"""General node utils."""


import logging
import uuid
from typing import List, Optional, Protocol, runtime_checkable
@@ -68,7 +67,7 @@ def build_nodes_from_splits(
embedding=document.embedding,
excluded_embed_metadata_keys=document.excluded_embed_metadata_keys,
excluded_llm_metadata_keys=document.excluded_llm_metadata_keys,
metadata_seperator=document.metadata_seperator,
metadata_seperator=document.metadata_separator,
metadata_template=document.metadata_template,
text_template=document.text_template,
relationships=relationships,
136 changes: 96 additions & 40 deletions llama-index-core/llama_index/core/schema.py
@@ -28,6 +28,7 @@

import filetype
from dataclasses_json import DataClassJsonMixin
from deprecated import deprecated
from typing_extensions import Self

from llama_index.core.bridge.pydantic import (
@@ -48,7 +49,7 @@

if TYPE_CHECKING: # pragma: no cover
from haystack.schema import Document as HaystackDocument # type: ignore
from llama_cloud.types.cloud_document import CloudDocument
from llama_cloud.types.cloud_document import CloudDocument # type: ignore
from semantic_kernel.memory.memory_record import MemoryRecord # type: ignore

from llama_index.core.bridge.langchain import Document as LCDocument # type: ignore
@@ -303,6 +304,7 @@ class BaseNode(BaseComponent):
metadata_separator: str = Field(
default="\n",
description="Separator between metadata fields when converting to string.",
alias="metadata_seperator",
)

@classmethod
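
The alias keeps payloads and call sites that used the old misspelled name working. A hedged check, assuming the alias is accepted at construction time in the same way node_utils.py above relies on:

    from llama_index.core.schema import TextNode

    node = TextNode(text="hi", metadata_seperator=" | ")   # old keyword spelling, accepted via the alias
    assert node.metadata_separator == " | "                # new field name is populated
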
@@ -422,10 +424,21 @@ def ref_doc_id(self) -> Optional[str]: # pragma: no cover
return source_node.node_id

@property
def extra_info(self) -> Dict[str, Any]: # pragma: no cover
"""TODO: DEPRECATED: Extra info."""
@deprecated(
version="0.12.2",
reason="'extra_info' is deprecated, use 'metadata' instead.",
)
def extra_info(self) -> dict[str, Any]: # pragma: no cover
return self.metadata

@extra_info.setter
@deprecated(
version="0.12.2",
reason="'extra_info' is deprecated, use 'metadata' instead.",
)
def extra_info(self, extra_info: dict[str, Any]) -> None: # pragma: no cover
self.metadata = extra_info

def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
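
With the deprecated package, reading extra_info keeps working but now warns callers towards metadata. A small illustration; the warning category is the package default and is assumed here, not shown in the diff:

    import warnings
    from llama_index.core.schema import TextNode

    node = TextNode(text="hello", metadata={"lang": "en"})
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert node.extra_info == {"lang": "en"}       # still works, but is deprecated
    print([w.category.__name__ for w in caught])       # expected: ['DeprecationWarning'] (assumed default)
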
@@ -540,18 +553,25 @@ def hash(self) -> str:


class Node(BaseNode):
text: MediaResource | None = Field(
text_resource: MediaResource | None = Field(
default=None, description="Text content of the node."
)
image: MediaResource | None = Field(
image_resource: MediaResource | None = Field(
default=None, description="Image content of the node."
)
audio: MediaResource | None = Field(
audio_resource: MediaResource | None = Field(
default=None, description="Audio content of the node."
)
video: MediaResource | None = Field(
video_resource: MediaResource | None = Field(
default=None, description="Video content of the node."
)
text_template: str = Field(
default=DEFAULT_TEXT_NODE_TMPL,
description=(
"Template for how text_resource is formatted, with {content} and "
"{metadata_str} placeholders."
),
)

@classmethod
def class_name(cls) -> str:
@@ -562,33 +582,36 @@ def get_type(cls) -> str:
"""Get Object type."""
return ObjectType.MULTIMODAL

def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
"""Get the text content for the node if available.
Provided for backward compatibility, use self.text directly instead.
Provided for backward compatibility, use self.text_resource directly instead.
"""
if self.text:
return self.text.text or ""
if self.text_resource:
return self.text_template.format(
content=self.text_resource.text or "",
metadata_str=self.get_metadata_str(metadata_mode),
).strip()
return ""

def set_content(self, value: str) -> None:
"""Set the text content of the node.
Provided for backward compatibility, set self.text instead.
"""
self.text = MediaResource(text=value)
self.text_resource = MediaResource(text=value)

@property
def hash(self) -> str:
doc_identities = []
if self.audio is not None:
doc_identities.append(self.audio.hash)
if self.image is not None:
doc_identities.append(self.image.hash)
if self.text is not None:
doc_identities.append(self.text.hash)
if self.video is not None:
doc_identities.append(self.video.hash)
if self.audio_resource is not None:
doc_identities.append(self.audio_resource.hash)
if self.image_resource is not None:
doc_identities.append(self.image_resource.hash)
if self.text_resource is not None:
doc_identities.append(self.text_resource.hash)
if self.video_resource is not None:
doc_identities.append(self.video_resource.hash)

doc_identity = "-".join(doc_identities)
return str(sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest())
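
On the new Node, get_content() renders text_resource through text_template and now omits metadata by default, while hash folds in every resource that is set. A rough usage sketch; the exact rendered strings depend on the default templates:

    from llama_index.core.schema import MediaResource, MetadataMode, Node

    node = Node(
        text_resource=MediaResource(text="hello world"),
        metadata={"author": "alice"},
    )
    print(node.get_content())                   # "hello world" (MetadataMode.NONE is now the default)
    print(node.get_content(MetadataMode.ALL))   # roughly "author: alice\n\nhello world"
    print(node.hash)                            # sha256 over the hashes of all set resources
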
@@ -602,7 +625,10 @@ class TextNode(BaseNode):
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""This is needed to help static checkers with inherited fields."""
"""Make TextNode forward-compatible with Node by supporting 'text_resource' in the constructor."""
if "text_resource" in kwargs:
tr = kwargs.pop("text_resource")
kwargs["text"] = tr["text"]
super().__init__(*args, **kwargs)

text: str = Field(default="", description="Text content of the node.")
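
The net effect is that a Node-style payload storing its text under "text_resource" can be fed straight into TextNode, which copies text_resource["text"] into its own text field. For example:

    from llama_index.core.schema import TextNode

    payload = {"text_resource": {"text": "chunk contents"}}   # Node-style serialization
    tn = TextNode(**payload)
    assert tn.text == "chunk contents"
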
@@ -686,6 +712,10 @@ def get_text(self) -> str:
return self.get_content(metadata_mode=MetadataMode.NONE)

@property
@deprecated(
version="0.12.2",
reason="'node_info' is deprecated, use 'get_node_info' instead.",
)
def node_info(self) -> Dict[str, Any]:
"""Deprecated: Get node info."""
return self.get_node_info()
@@ -881,21 +911,43 @@ def get_embedding(self) -> List[float]:
# Document Classes for Readers


class Document(TextNode):
class Document(Node):
"""Generic interface for a data document.
This document connects to data sources.
"""

# TODO: A lot of backwards compatibility logic here, clean up
id_: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID of the node.",
alias="doc_id",
)
def __init__(self, **data: Any) -> None:
"""Keeps backward compatibility with old 'Document' versions.
If 'text' was passed, store it in 'text_resource'.
If 'doc_id' was passed, store it in 'id_'.
If 'extra_info' was passed, store it in 'metadata'.
"""
if "doc_id" in data:
if "id_" in data:
msg = "Cannot pass both 'doc_id' and 'id_' to create a Document, use 'id_'"
raise ValueError(msg)
data["id_"] = data.pop("doc_id")

if "extra_info" in data:
if "metadata" in data:
msg = "Cannot pass both 'extra_info' and 'metadata' to create a Document, use 'metadata'"
raise ValueError(msg)
data["metadata"] = data.pop("extra_info")

if "text" in data:
if "text_resource" in data:
msg = "Cannot pass both 'text' and 'text_resource' to create a Document, use 'text_resource'"
raise ValueError(msg)
data["text_resource"] = MediaResource(text=data.pop("text"))

super().__init__(**data)

_compat_fields = {"doc_id": "id_", "extra_info": "metadata"}
@property
def text(self) -> str:
"""Provided for backward compatibility, it returns the content of text_resource."""
return self.get_content()

@classmethod
def get_type(cls) -> str:
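
In other words, legacy Document(text=..., doc_id=..., extra_info=...) call sites keep working, while mixing an old spelling with its new counterpart fails fast. A quick sanity check of the constructor behaviour above:

    from llama_index.core.schema import Document

    doc = Document(text="hello", doc_id="doc-1", extra_info={"source": "legacy"})
    assert doc.id_ == "doc-1"
    assert doc.metadata == {"source": "legacy"}
    assert doc.text_resource is not None and doc.text_resource.text == "hello"

    try:
        Document(doc_id="doc-1", id_="doc-1")   # old and new spellings of the same field
    except ValueError as exc:
        print(exc)                              # "Cannot pass both 'doc_id' and 'id_' ..."
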
@@ -907,6 +959,10 @@ def doc_id(self) -> str:
"""Get document ID."""
return self.id_

@doc_id.setter
def doc_id(self, id_: str) -> None:
self.id_ = id_

def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
@@ -916,18 +972,18 @@ def __str__(self) -> str:
)
return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"

@deprecated(
version="0.12.2",
reason="'get_doc_id' is deprecated, access the 'id_' property instead.",
)
def get_doc_id(self) -> str:
"""TODO: Deprecated: Get document ID."""
return self.id_

def __setattr__(self, name: str, value: object) -> None:
if name in self._compat_fields:
name = self._compat_fields[name]
super().__setattr__(name, value)

def to_langchain_format(self) -> LCDocument:
"""Convert struct to LangChain document format."""
from llama_index.core.bridge.langchain import Document as LCDocument
from llama_index.core.bridge.langchain import (
Document as LCDocument, # type: ignore
)

metadata = self.metadata or {}
return LCDocument(page_content=self.text, metadata=metadata, id=self.id_)
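
Because text is now a computed property over text_resource, the LangChain round trip should look unchanged to callers. A hedged sketch, assuming the langchain bridge is installed and from_langchain_format mirrors the conversion above:

    from llama_index.core.schema import Document

    doc = Document(text="hello", metadata={"source": "unit-test"})
    lc_doc = doc.to_langchain_format()
    assert lc_doc.page_content == "hello"          # routed through the new text property
    restored = Document.from_langchain_format(lc_doc)
    assert restored.text == "hello"
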
@@ -941,7 +997,7 @@ def from_langchain_format(cls, doc: LCDocument) -> Document:

def to_haystack_format(self) -> HaystackDocument:
"""Convert struct to Haystack document format."""
from haystack.schema import Document as HaystackDocument
from haystack import Document as HaystackDocument # type: ignore

return HaystackDocument(
content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_
@@ -973,7 +1029,7 @@ def from_embedchain_format(cls, doc: Dict[str, Any]) -> Document:
def to_semantic_kernel_format(self) -> MemoryRecord:
"""Convert struct to Semantic Kernel document format."""
import numpy as np
from semantic_kernel.memory.memory_record import MemoryRecord
from semantic_kernel.memory.memory_record import MemoryRecord # type: ignore

return MemoryRecord(
id=self.id_,
@@ -1015,7 +1071,7 @@ def class_name(cls) -> str:

def to_cloud_document(self) -> CloudDocument:
"""Convert to LlamaCloud document type."""
from llama_cloud.types.cloud_document import CloudDocument
from llama_cloud.types.cloud_document import CloudDocument # type: ignore

return CloudDocument(
text=self.text,
@@ -3,7 +3,7 @@
import asyncio
from typing import Any, Dict, List, Optional, Sequence, Tuple

from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.schema import BaseNode, Document, TextNode
from llama_index.core.storage.docstore.types import BaseDocumentStore, RefDocInfo
from llama_index.core.storage.docstore.utils import doc_to_json, json_to_doc
from llama_index.core.storage.kvstore.types import DEFAULT_BATCH_SIZE, BaseKVStore
@@ -176,7 +176,7 @@ def _prepare_kv_pairs(
"Set allow_update to True to overwrite."
)
ref_doc_info = None
if isinstance(node, TextNode) and node.ref_doc_id is not None:
if isinstance(node, (TextNode, Document)) and node.ref_doc_id is not None:
ref_doc_info = self.get_ref_doc_info(node.ref_doc_id) or RefDocInfo()

(
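
Since Document no longer inherits from TextNode, it has to be listed explicitly for ref-doc bookkeeping to run. A hedged sketch of the case this guards, using SimpleDocumentStore; a Document carrying a SOURCE relationship is unusual but valid:

    from llama_index.core.schema import Document, NodeRelationship, RelatedNodeInfo
    from llama_index.core.storage.docstore import SimpleDocumentStore

    source = Document(text="the original upload")
    derived = Document(text="a summary derived from the source")
    derived.relationships[NodeRelationship.SOURCE] = RelatedNodeInfo(node_id=source.id_)

    docstore = SimpleDocumentStore()
    docstore.add_documents([source, derived])
    print(docstore.get_ref_doc_info(source.id_))   # now populated for Document-typed entries too
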
8 changes: 4 additions & 4 deletions llama-index-core/tests/schema/test_node.py
@@ -12,10 +12,10 @@ def test_get_content():

def test_hash():
node = Node()
node.audio = MediaResource(data=b"test audio", mimetype="audio/aac")
node.image = MediaResource(data=b"test image", mimetype="image/png")
node.text = MediaResource(text="some text", mimetype="text/plain")
node.video = MediaResource(data=b"some video", mimetype="video/mpeg")
node.audio_resource = MediaResource(data=b"test audio", mimetype="audio/aac")
node.image_resource = MediaResource(data=b"test image", mimetype="image/png")
node.text_resource = MediaResource(text="some text", mimetype="text/plain")
node.video_resource = MediaResource(data=b"some video", mimetype="video/mpeg")
assert (
node.hash == "ee411edd3dffb27470eef165ccf4df9fabaa02e7c7c39415950d3ac4d7e35e61"
)
@@ -1,5 +1,6 @@
import logging
from typing import Any, Optional, Sequence

import torch
from llama_index.core.base.llms.types import (
ChatMessage,
@@ -24,13 +25,13 @@
from llama_index.core.types import PydanticProgramMode
from llama_index.llms.modelscope.utils import (
chat_message_to_modelscope_messages,
text_to_completion_response,
modelscope_message_to_chat_response,
text_to_completion_response,
)

from modelscope.pipelines import pipeline as pipeline_builder

DEFAULT_MODELSCOPE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
DEFAULT_MODELSCOPE_MODEL = "qwen/Qwen2-0.5B-Instruct"
DEFAULT_MODELSCOPE_MODEL_REVISION = "master"
DEFAULT_MODELSCOPE_TASK = "chat"
DEFAULT_MODELSCOPE_DTYPE = "float16"
@@ -18,7 +18,7 @@ ModelScopeLLM = "llama-index"
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
python_version = "3.9"

[tool.poetry]
authors = ["ModelScope <[email protected]>"]
@@ -30,7 +30,7 @@ readme = "README.md"
version = "0.4.1"

[tool.poetry.dependencies]
python = ">=3.9,<3.12"
python = ">=3.9,<4.0"
modelscope = {extras = ["framework"], version = ">=1.12.0"}
torch = "^2.1.2"
llama-index-core = "^0.12.0"