Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refact: Refactor Document to be natively multimodal #17204

Merged
merged 10 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions llama-index-core/llama_index/core/node_parser/node_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""General node utils."""


import logging
import uuid
from typing import List, Optional, Protocol, runtime_checkable
Expand Down Expand Up @@ -68,7 +67,7 @@ def build_nodes_from_splits(
embedding=document.embedding,
excluded_embed_metadata_keys=document.excluded_embed_metadata_keys,
excluded_llm_metadata_keys=document.excluded_llm_metadata_keys,
metadata_seperator=document.metadata_seperator,
metadata_seperator=document.metadata_separator,
metadata_template=document.metadata_template,
text_template=document.text_template,
relationships=relationships,
Expand Down
136 changes: 96 additions & 40 deletions llama-index-core/llama_index/core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import filetype
from dataclasses_json import DataClassJsonMixin
from deprecated import deprecated
from typing_extensions import Self

from llama_index.core.bridge.pydantic import (
Expand All @@ -48,7 +49,7 @@

if TYPE_CHECKING: # pragma: no cover
from haystack.schema import Document as HaystackDocument # type: ignore
from llama_cloud.types.cloud_document import CloudDocument
from llama_cloud.types.cloud_document import CloudDocument # type: ignore
from semantic_kernel.memory.memory_record import MemoryRecord # type: ignore

from llama_index.core.bridge.langchain import Document as LCDocument # type: ignore
Expand Down Expand Up @@ -303,6 +304,7 @@ class BaseNode(BaseComponent):
metadata_separator: str = Field(
default="\n",
description="Separator between metadata fields when converting to string.",
alias="metadata_seperator",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol

)

@classmethod
Expand Down Expand Up @@ -422,10 +424,21 @@ def ref_doc_id(self) -> Optional[str]: # pragma: no cover
return source_node.node_id

@property
def extra_info(self) -> Dict[str, Any]: # pragma: no cover
"""TODO: DEPRECATED: Extra info."""
@deprecated(
version="0.12.2",
reason="'extra_info' is deprecated, use 'metadata' instead.",
)
def extra_info(self) -> dict[str, Any]:  # pragma: no cover
return self.metadata

@extra_info.setter
@deprecated(
version="0.12.2",
reason="'extra_info' is deprecated, use 'metadata' instead.",
)
def extra_info(self, extra_info: dict[str, Any]) -> None:  # pragma: no cover
self.metadata = extra_info

def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
Expand Down Expand Up @@ -540,18 +553,25 @@ def hash(self) -> str:


class Node(BaseNode):
text: MediaResource | None = Field(
text_resource: MediaResource | None = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this have an impact on the current openai multimodal stuff we had working?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it won't, even if the logic is almost duplicated (and we might want to refactor it later), that depends on ChatMessage and not Node.

default=None, description="Text content of the node."
)
image: MediaResource | None = Field(
image_resource: MediaResource | None = Field(
default=None, description="Image content of the node."
)
audio: MediaResource | None = Field(
audio_resource: MediaResource | None = Field(
default=None, description="Audio content of the node."
)
video: MediaResource | None = Field(
video_resource: MediaResource | None = Field(
default=None, description="Video content of the node."
)
text_template: str = Field(
default=DEFAULT_TEXT_NODE_TMPL,
description=(
"Template for how text_resource is formatted, with {content} and "
"{metadata_str} placeholders."
),
)

@classmethod
def class_name(cls) -> str:
Expand All @@ -562,33 +582,36 @@ def get_type(cls) -> str:
"""Get Object type."""
return ObjectType.MULTIMODAL

def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
"""Get the text content for the node if available.

Provided for backward compatibility, use self.text directly instead.
Provided for backward compatibility, use self.text_resource directly instead.
"""
if self.text:
return self.text.text or ""
if self.text_resource:
return self.text_template.format(
content=self.text_resource.text or "",
metadata_str=self.get_metadata_str(metadata_mode),
).strip()
return ""

def set_content(self, value: str) -> None:
"""Set the text content of the node.

Provided for backward compatibility, set self.text instead.
"""
self.text = MediaResource(text=value)
self.text_resource = MediaResource(text=value)

@property
def hash(self) -> str:
doc_identities = []
if self.audio is not None:
doc_identities.append(self.audio.hash)
if self.image is not None:
doc_identities.append(self.image.hash)
if self.text is not None:
doc_identities.append(self.text.hash)
if self.video is not None:
doc_identities.append(self.video.hash)
if self.audio_resource is not None:
doc_identities.append(self.audio_resource.hash)
if self.image_resource is not None:
doc_identities.append(self.image_resource.hash)
if self.text_resource is not None:
doc_identities.append(self.text_resource.hash)
if self.video_resource is not None:
doc_identities.append(self.video_resource.hash)

doc_identity = "-".join(doc_identities)
return str(sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest())
Expand All @@ -602,7 +625,10 @@ class TextNode(BaseNode):
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""This is needed to help static checkers with inherited fields."""
"""Make TextNode forward-compatible with Node by supporting 'text_resource' in the constructor."""
if "text_resource" in kwargs:
tr = kwargs.pop("text_resource")
kwargs["text"] = tr["text"]
super().__init__(*args, **kwargs)

text: str = Field(default="", description="Text content of the node.")
Expand Down Expand Up @@ -686,6 +712,10 @@ def get_text(self) -> str:
return self.get_content(metadata_mode=MetadataMode.NONE)

@property
@deprecated(
version="0.12.2",
reason="'node_info' is deprecated, use 'get_node_info' instead.",
)
def node_info(self) -> Dict[str, Any]:
"""Deprecated: Get node info."""
return self.get_node_info()
Expand Down Expand Up @@ -881,21 +911,43 @@ def get_embedding(self) -> List[float]:
# Document Classes for Readers


class Document(TextNode):
class Document(Node):
"""Generic interface for a data document.

This document connects to data sources.

"""

# TODO: A lot of backwards compatibility logic here, clean up
id_: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID of the node.",
alias="doc_id",
)
def __init__(self, **data: Any) -> None:
"""Keeps backward compatibility with old 'Document' versions.

If 'text' was passed, store it in 'text_resource'.
If 'doc_id' was passed, store it in 'id_'.
If 'extra_info' was passed, store it in 'metadata'.
"""
if "doc_id" in data:
if "id_" in data:
msg = "Cannot pass both 'doc_id' and 'id_' to create a Document, use 'id_'"
raise ValueError(msg)
data["id_"] = data.pop("doc_id")

if "extra_info" in data:
if "metadata" in data:
msg = "Cannot pass both 'extra_info' and 'metadata' to create a Document, use 'metadata'"
raise ValueError(msg)
data["metadata"] = data.pop("extra_info")

if "text" in data:
if "text_resource" in data:
msg = "Cannot pass both 'text' and 'text_resource' to create a Document, use 'text_resource'"
raise ValueError(msg)
data["text_resource"] = MediaResource(text=data.pop("text"))

super().__init__(**data)

_compat_fields = {"doc_id": "id_", "extra_info": "metadata"}
@property
def text(self) -> str:
"""Provided for backward compatibility, it returns the content of text_resource."""
return self.get_content()

@classmethod
def get_type(cls) -> str:
Expand All @@ -907,6 +959,10 @@ def doc_id(self) -> str:
"""Get document ID."""
return self.id_

@doc_id.setter
def doc_id(self, id_: str) -> None:
self.id_ = id_

def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
Expand All @@ -916,18 +972,18 @@ def __str__(self) -> str:
)
return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"

@deprecated(
version="0.12.2",
reason="'get_doc_id' is deprecated, access the 'id_' property instead.",
)
def get_doc_id(self) -> str:
"""TODO: Deprecated: Get document ID."""
return self.id_

def __setattr__(self, name: str, value: object) -> None:
if name in self._compat_fields:
name = self._compat_fields[name]
super().__setattr__(name, value)

def to_langchain_format(self) -> LCDocument:
"""Convert struct to LangChain document format."""
from llama_index.core.bridge.langchain import Document as LCDocument
from llama_index.core.bridge.langchain import (
Document as LCDocument, # type: ignore
)

metadata = self.metadata or {}
return LCDocument(page_content=self.text, metadata=metadata, id=self.id_)
Expand All @@ -941,7 +997,7 @@ def from_langchain_format(cls, doc: LCDocument) -> Document:

def to_haystack_format(self) -> HaystackDocument:
"""Convert struct to Haystack document format."""
from haystack.schema import Document as HaystackDocument
from haystack import Document as HaystackDocument # type: ignore

return HaystackDocument(
content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_
Expand Down Expand Up @@ -973,7 +1029,7 @@ def from_embedchain_format(cls, doc: Dict[str, Any]) -> Document:
def to_semantic_kernel_format(self) -> MemoryRecord:
"""Convert struct to Semantic Kernel document format."""
import numpy as np
from semantic_kernel.memory.memory_record import MemoryRecord
from semantic_kernel.memory.memory_record import MemoryRecord # type: ignore

return MemoryRecord(
id=self.id_,
Expand Down Expand Up @@ -1015,7 +1071,7 @@ def class_name(cls) -> str:

def to_cloud_document(self) -> CloudDocument:
"""Convert to LlamaCloud document type."""
from llama_cloud.types.cloud_document import CloudDocument
from llama_cloud.types.cloud_document import CloudDocument # type: ignore

return CloudDocument(
text=self.text,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from typing import Any, Dict, List, Optional, Sequence, Tuple

from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.schema import BaseNode, Document, TextNode
from llama_index.core.storage.docstore.types import BaseDocumentStore, RefDocInfo
from llama_index.core.storage.docstore.utils import doc_to_json, json_to_doc
from llama_index.core.storage.kvstore.types import DEFAULT_BATCH_SIZE, BaseKVStore
Expand Down Expand Up @@ -176,7 +176,7 @@ def _prepare_kv_pairs(
"Set allow_update to True to overwrite."
)
ref_doc_info = None
if isinstance(node, TextNode) and node.ref_doc_id is not None:
if isinstance(node, (TextNode, Document)) and node.ref_doc_id is not None:
ref_doc_info = self.get_ref_doc_info(node.ref_doc_id) or RefDocInfo()

(
Expand Down
8 changes: 4 additions & 4 deletions llama-index-core/tests/schema/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def test_get_content():

def test_hash():
node = Node()
node.audio = MediaResource(data=b"test audio", mimetype="audio/aac")
node.image = MediaResource(data=b"test image", mimetype="image/png")
node.text = MediaResource(text="some text", mimetype="text/plain")
node.video = MediaResource(data=b"some video", mimetype="video/mpeg")
node.audio_resource = MediaResource(data=b"test audio", mimetype="audio/aac")
node.image_resource = MediaResource(data=b"test image", mimetype="image/png")
node.text_resource = MediaResource(text="some text", mimetype="text/plain")
node.video_resource = MediaResource(data=b"some video", mimetype="video/mpeg")
assert (
node.hash == "ee411edd3dffb27470eef165ccf4df9fabaa02e7c7c39415950d3ac4d7e35e61"
)
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Any, Optional, Sequence

import torch
from llama_index.core.base.llms.types import (
ChatMessage,
Expand All @@ -24,13 +25,13 @@
from llama_index.core.types import PydanticProgramMode
from llama_index.llms.modelscope.utils import (
chat_message_to_modelscope_messages,
text_to_completion_response,
modelscope_message_to_chat_response,
text_to_completion_response,
)

from modelscope.pipelines import pipeline as pipeline_builder

DEFAULT_MODELSCOPE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
DEFAULT_MODELSCOPE_MODEL = "qwen/Qwen2-0.5B-Instruct"
DEFAULT_MODELSCOPE_MODEL_REVISION = "master"
DEFAULT_MODELSCOPE_TASK = "chat"
DEFAULT_MODELSCOPE_DTYPE = "float16"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ModelScopeLLM = "llama-index"
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
python_version = "3.9"

[tool.poetry]
authors = ["ModelScope <[email protected]>"]
Expand All @@ -30,7 +30,7 @@ readme = "README.md"
version = "0.4.1"

[tool.poetry.dependencies]
python = ">=3.9,<3.12"
python = ">=3.9,<4.0"
modelscope = {extras = ["framework"], version = ">=1.12.0"}
torch = "^2.1.2"
llama-index-core = "^0.12.0"
Expand Down
Loading
Loading