Skip to content

Commit

Permalink
Docstring fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
claudevdm committed Dec 13, 2024
1 parent 0a28de3 commit 9637269
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 28 deletions.
33 changes: 16 additions & 17 deletions sdks/python/apache_beam/ml/rag/chunking/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _assign_chunk_id(chunk_id_fn: ChunkIdFn, chunk: Chunk):
class ChunkingTransformProvider(MLTransformProvider):
def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None):
"""Base class for chunking transforms in RAG pipelines.
ChunkingTransformProvider defines the interface for splitting documents
into chunks for embedding and retrieval. Implementations should define how
to split content while preserving metadata and managing chunk IDs.
Expand All @@ -49,24 +49,23 @@ def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None):
4. Optionally assigns unique IDs to chunks (configurable via chunk_id_fn).
Example usage:
```python
class MyChunker(ChunkingTransformProvider):
def get_splitter_transform(self):
return beam.ParDo(MySplitterDoFn())
chunker = MyChunker(chunk_id_fn=my_id_function)
with beam.Pipeline() as p:
chunks = (
p
| beam.Create([{'text': 'document...', 'source': 'doc.txt'}])
| MLTransform(...).with_transform(chunker))
```
```python
class MyChunker(ChunkingTransformProvider):
def get_splitter_transform(self):
return beam.ParDo(MySplitterDoFn())
chunker = MyChunker(chunk_id_fn=my_id_function)
with beam.Pipeline() as p:
chunks = (
p
| beam.Create([{'text': 'document...', 'source': 'doc.txt'}])
| MLTransform(...).with_transform(chunker))
```
Args:
chunk_id_fn: Optional function to generate chunk IDs. If not provided,
random UUIDs will be used. Function should take a Chunk and return
str.
chunk_id_fn: Optional function to generate chunk IDs. If not provided,
random UUIDs will be used. Function should take a Chunk and return str.
"""
self.assign_chunk_id_fn = functools.partial(
_assign_chunk_id, chunk_id_fn) if chunk_id_fn is not None else None
Expand Down
12 changes: 3 additions & 9 deletions sdks/python/apache_beam/ml/rag/embeddings/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,14 @@


class HuggingfaceTextEmbeddings(EmbeddingsManager):
"""SentenceTransformer embeddings for RAG pipeline.
Extends EmbeddingsManager to work with RAG-specific types:
- Input: Chunk objects containing text to embed
- Output: Chunk objects with embedding property set
"""
def __init__(
self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs):
"""Initialize RAG embeddings.
"""Utilizes huggingface SentenceTransformer embeddings for RAG pipeline.
Args:
model_name: Name of the sentence-transformers model to use
max_seq_length: Maximum sequence length for the model
**kwargs: Additional arguments passed to parent
**kwargs: Additional arguments including ModelHandlers arguments
"""
super().__init__(type_adapter=create_rag_adapter(), **kwargs)
self.model_name = model_name
Expand Down
7 changes: 5 additions & 2 deletions sdks/python/apache_beam/ml/rag/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
@dataclass
class Content:
"""Container for embeddable content. Add new types as when as necessary.
Args:
text: Text content to be embedded
"""
text: Optional[str] = None

Expand All @@ -42,7 +45,7 @@ class Content:
class Embedding:
"""Represents vector embeddings.
Attributes:
Args:
dense_embedding: Dense vector representation
sparse_embedding: Optional sparse vector representation for hybrid
search
Expand All @@ -56,7 +59,7 @@ class Embedding:
class Chunk:
"""Represents a chunk of embeddable content with metadata.
Attributes:
Args:
content: The actual content of the chunk
id: Unique identifier for the chunk
index: Index of this chunk within the original document
Expand Down

0 comments on commit 9637269

Please sign in to comment.