-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refact: Refactor Document
to be natively multimodal
#17204
Changes from all commits
c43534f
8f9d9cb
78329c4
38ac812
105e2f6
6cda23a
f49d7f7
46eba16
6dc41f8
bc4ac96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
|
||
import filetype | ||
from dataclasses_json import DataClassJsonMixin | ||
from deprecated import deprecated | ||
from typing_extensions import Self | ||
|
||
from llama_index.core.bridge.pydantic import ( | ||
|
@@ -48,7 +49,7 @@ | |
|
||
if TYPE_CHECKING: # pragma: no cover | ||
from haystack.schema import Document as HaystackDocument # type: ignore | ||
from llama_cloud.types.cloud_document import CloudDocument | ||
from llama_cloud.types.cloud_document import CloudDocument # type: ignore | ||
from semantic_kernel.memory.memory_record import MemoryRecord # type: ignore | ||
|
||
from llama_index.core.bridge.langchain import Document as LCDocument # type: ignore | ||
|
@@ -303,6 +304,7 @@ class BaseNode(BaseComponent): | |
metadata_separator: str = Field( | ||
default="\n", | ||
description="Separator between metadata fields when converting to string.", | ||
alias="metadata_seperator", | ||
) | ||
|
||
@classmethod | ||
|
@@ -422,10 +424,21 @@ def ref_doc_id(self) -> Optional[str]: # pragma: no cover | |
return source_node.node_id | ||
|
||
@property | ||
def extra_info(self) -> Dict[str, Any]: # pragma: no cover | ||
"""TODO: DEPRECATED: Extra info.""" | ||
@deprecated( | ||
version="0.12.2", | ||
reason="'extra_info' is deprecated, use 'metadata' instead.", | ||
) | ||
def extra_info(self) -> dict[str, Any]: # pragma: no coverde | ||
return self.metadata | ||
|
||
@extra_info.setter | ||
@deprecated( | ||
version="0.12.2", | ||
reason="'extra_info' is deprecated, use 'metadata' instead.", | ||
) | ||
def extra_info(self, extra_info: dict[str, Any]) -> None: # pragma: no coverde | ||
self.metadata = extra_info | ||
|
||
def __str__(self) -> str: | ||
source_text_truncated = truncate_text( | ||
self.get_content().strip(), TRUNCATE_LENGTH | ||
|
@@ -540,18 +553,25 @@ def hash(self) -> str: | |
|
||
|
||
class Node(BaseNode): | ||
text: MediaResource | None = Field( | ||
text_resource: MediaResource | None = Field( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will this have an impact on the current OpenAI multimodal stuff we had working? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it won't; even though the logic is almost duplicated (and we might want to refactor it later), that depends on |
||
default=None, description="Text content of the node." | ||
) | ||
image: MediaResource | None = Field( | ||
image_resource: MediaResource | None = Field( | ||
default=None, description="Image content of the node." | ||
) | ||
audio: MediaResource | None = Field( | ||
audio_resource: MediaResource | None = Field( | ||
default=None, description="Audio content of the node." | ||
) | ||
video: MediaResource | None = Field( | ||
video_resource: MediaResource | None = Field( | ||
default=None, description="Video content of the node." | ||
) | ||
text_template: str = Field( | ||
default=DEFAULT_TEXT_NODE_TMPL, | ||
description=( | ||
"Template for how text_resource is formatted, with {content} and " | ||
"{metadata_str} placeholders." | ||
), | ||
) | ||
|
||
@classmethod | ||
def class_name(cls) -> str: | ||
|
@@ -562,33 +582,36 @@ def get_type(cls) -> str: | |
"""Get Object type.""" | ||
return ObjectType.MULTIMODAL | ||
|
||
def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: | ||
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: | ||
"""Get the text content for the node if available. | ||
|
||
Provided for backward compatibility, use self.text directly instead. | ||
Provided for backward compatibility, use self.text_resource directly instead. | ||
""" | ||
if self.text: | ||
return self.text.text or "" | ||
if self.text_resource: | ||
return self.text_template.format( | ||
content=self.text_resource.text or "", | ||
metadata_str=self.get_metadata_str(metadata_mode), | ||
).strip() | ||
return "" | ||
|
||
def set_content(self, value: str) -> None: | ||
"""Set the text content of the node. | ||
|
||
Provided for backward compatibility, set self.text instead. | ||
""" | ||
self.text = MediaResource(text=value) | ||
self.text_resource = MediaResource(text=value) | ||
|
||
@property | ||
def hash(self) -> str: | ||
doc_identities = [] | ||
if self.audio is not None: | ||
doc_identities.append(self.audio.hash) | ||
if self.image is not None: | ||
doc_identities.append(self.image.hash) | ||
if self.text is not None: | ||
doc_identities.append(self.text.hash) | ||
if self.video is not None: | ||
doc_identities.append(self.video.hash) | ||
if self.audio_resource is not None: | ||
doc_identities.append(self.audio_resource.hash) | ||
if self.image_resource is not None: | ||
doc_identities.append(self.image_resource.hash) | ||
if self.text_resource is not None: | ||
doc_identities.append(self.text_resource.hash) | ||
if self.video_resource is not None: | ||
doc_identities.append(self.video_resource.hash) | ||
|
||
doc_identity = "-".join(doc_identities) | ||
return str(sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()) | ||
|
@@ -602,7 +625,10 @@ class TextNode(BaseNode): | |
""" | ||
|
||
def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
"""This is needed to help static checkers with inherited fields.""" | ||
"""Make TextNode forward-compatible with Node by supporting 'text_resource' in the constructor.""" | ||
if "text_resource" in kwargs: | ||
tr = kwargs.pop("text_resource") | ||
kwargs["text"] = tr["text"] | ||
super().__init__(*args, **kwargs) | ||
|
||
text: str = Field(default="", description="Text content of the node.") | ||
|
@@ -686,6 +712,10 @@ def get_text(self) -> str: | |
return self.get_content(metadata_mode=MetadataMode.NONE) | ||
|
||
@property | ||
@deprecated( | ||
version="0.12.2", | ||
reason="'node_info' is deprecated, use 'get_node_info' instead.", | ||
) | ||
def node_info(self) -> Dict[str, Any]: | ||
"""Deprecated: Get node info.""" | ||
return self.get_node_info() | ||
|
@@ -881,21 +911,43 @@ def get_embedding(self) -> List[float]: | |
# Document Classes for Readers | ||
|
||
|
||
class Document(TextNode): | ||
class Document(Node): | ||
"""Generic interface for a data document. | ||
|
||
This document connects to data sources. | ||
|
||
""" | ||
|
||
# TODO: A lot of backwards compatibility logic here, clean up | ||
id_: str = Field( | ||
default_factory=lambda: str(uuid.uuid4()), | ||
description="Unique ID of the node.", | ||
alias="doc_id", | ||
) | ||
def __init__(self, **data: Any) -> None: | ||
"""Keeps backward compatibility with old 'Document' versions. | ||
|
||
If 'text' was passed, store it in 'text_resource'. | ||
If 'doc_id' was passed, store it in 'id_'. | ||
If 'extra_info' was passed, store it in 'metadata'. | ||
""" | ||
if "doc_id" in data: | ||
if "id_" in data: | ||
msg = "Cannot pass both 'doc_id' and 'id_' to create a Document, use 'id_'" | ||
raise ValueError(msg) | ||
data["id_"] = data.pop("doc_id") | ||
|
||
if "extra_info" in data: | ||
if "metadata" in data: | ||
msg = "Cannot pass both 'extra_info' and 'metadata' to create a Document, use 'metadata'" | ||
raise ValueError(msg) | ||
data["metadata"] = data.pop("extra_info") | ||
|
||
if "text" in data: | ||
if "text_resource" in data: | ||
msg = "Cannot pass both 'text' and 'text_resource' to create a Document, use 'text_resource'" | ||
raise ValueError(msg) | ||
data["text_resource"] = MediaResource(text=data.pop("text")) | ||
|
||
super().__init__(**data) | ||
|
||
_compat_fields = {"doc_id": "id_", "extra_info": "metadata"} | ||
@property | ||
def text(self) -> str: | ||
"""Provided for backward compatibility, it returns the content of text_resource.""" | ||
return self.get_content() | ||
|
||
@classmethod | ||
def get_type(cls) -> str: | ||
|
@@ -907,6 +959,10 @@ def doc_id(self) -> str: | |
"""Get document ID.""" | ||
return self.id_ | ||
|
||
@doc_id.setter | ||
def doc_id(self, id_: str) -> None: | ||
self.id_ = id_ | ||
|
||
def __str__(self) -> str: | ||
source_text_truncated = truncate_text( | ||
self.get_content().strip(), TRUNCATE_LENGTH | ||
|
@@ -916,18 +972,18 @@ def __str__(self) -> str: | |
) | ||
return f"Doc ID: {self.doc_id}\n{source_text_wrapped}" | ||
|
||
@deprecated( | ||
version="0.12.2", | ||
reason="'get_doc_id' is deprecated, access the 'id_' property instead.", | ||
) | ||
def get_doc_id(self) -> str: | ||
"""TODO: Deprecated: Get document ID.""" | ||
return self.id_ | ||
|
||
def __setattr__(self, name: str, value: object) -> None: | ||
if name in self._compat_fields: | ||
name = self._compat_fields[name] | ||
super().__setattr__(name, value) | ||
|
||
def to_langchain_format(self) -> LCDocument: | ||
"""Convert struct to LangChain document format.""" | ||
from llama_index.core.bridge.langchain import Document as LCDocument | ||
from llama_index.core.bridge.langchain import ( | ||
Document as LCDocument, # type: ignore | ||
) | ||
|
||
metadata = self.metadata or {} | ||
return LCDocument(page_content=self.text, metadata=metadata, id=self.id_) | ||
|
@@ -941,7 +997,7 @@ def from_langchain_format(cls, doc: LCDocument) -> Document: | |
|
||
def to_haystack_format(self) -> HaystackDocument: | ||
"""Convert struct to Haystack document format.""" | ||
from haystack.schema import Document as HaystackDocument | ||
from haystack import Document as HaystackDocument # type: ignore | ||
|
||
return HaystackDocument( | ||
content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_ | ||
|
@@ -973,7 +1029,7 @@ def from_embedchain_format(cls, doc: Dict[str, Any]) -> Document: | |
def to_semantic_kernel_format(self) -> MemoryRecord: | ||
"""Convert struct to Semantic Kernel document format.""" | ||
import numpy as np | ||
from semantic_kernel.memory.memory_record import MemoryRecord | ||
from semantic_kernel.memory.memory_record import MemoryRecord # type: ignore | ||
|
||
return MemoryRecord( | ||
id=self.id_, | ||
|
@@ -1015,7 +1071,7 @@ def class_name(cls) -> str: | |
|
||
def to_cloud_document(self) -> CloudDocument: | ||
"""Convert to LlamaCloud document type.""" | ||
from llama_cloud.types.cloud_document import CloudDocument | ||
from llama_cloud.types.cloud_document import CloudDocument # type: ignore | ||
|
||
return CloudDocument( | ||
text=self.text, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ ModelScopeLLM = "llama-index" | |
disallow_untyped_defs = true | ||
exclude = ["_static", "build", "examples", "notebooks", "venv"] | ||
ignore_missing_imports = true | ||
python_version = "3.8" | ||
python_version = "3.9" | ||
|
||
[tool.poetry] | ||
authors = ["ModelScope <[email protected]>"] | ||
|
@@ -30,7 +30,7 @@ readme = "README.md" | |
version = "0.4.1" | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.9,<3.12" | ||
python = ">=3.9,<4.0" | ||
modelscope = {extras = ["framework"], version = ">=1.12.0"} | ||
torch = "^2.1.2" | ||
llama-index-core = "^0.12.0" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lol