From f0fd82eb5313f56253b7ff7c8fc7d77276f04764 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Dec 2024 13:13:17 -0500 Subject: [PATCH] Linter fixes. --- sdks/python/apache_beam/ml/rag/chunking/base.py | 13 ++++++++----- .../python/apache_beam/ml/rag/chunking/base_test.py | 11 ++++++++--- .../python/apache_beam/ml/rag/chunking/langchain.py | 12 +++++++++--- .../apache_beam/ml/rag/chunking/langchain_test.py | 8 +++----- sdks/python/apache_beam/ml/rag/embeddings/base.py | 8 +++++--- .../apache_beam/ml/rag/embeddings/base_test.py | 7 +++++-- .../apache_beam/ml/rag/embeddings/huggingface.py | 11 ++++++----- .../ml/rag/embeddings/huggingface_test.py | 10 +++++++--- sdks/python/apache_beam/ml/rag/types.py | 9 +++++++-- sdks/python/apache_beam/ml/transforms/base.py | 12 +++--------- 10 files changed, 61 insertions(+), 40 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/chunking/base.py b/sdks/python/apache_beam/ml/rag/chunking/base.py index ca1dfa217dd..e26d54f0dbb 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/base.py +++ b/sdks/python/apache_beam/ml/rag/chunking/base.py @@ -15,13 +15,16 @@ # limitations under the License. # -import apache_beam as beam -from apache_beam.ml.transforms.base import MLTransformProvider -from apache_beam.ml.rag.types import Chunk -from typing import Optional, Dict, Any -from collections.abc import Callable import abc import functools +from collections.abc import Callable +from typing import Any +from typing import Dict +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.transforms.base import MLTransformProvider ChunkIdFn = Callable[[Chunk], str] diff --git a/sdks/python/apache_beam/ml/rag/chunking/base_test.py b/sdks/python/apache_beam/ml/rag/chunking/base_test.py index fd7c76a18b1..88456ac2134 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/base_test.py +++ b/sdks/python/apache_beam/ml/rag/chunking/base_test.py @@ -17,15 +17,20 @@ """Tests for apache_beam.ml.rag.chunking.base.""" import unittest +from typing import Any +from typing import Dict +from typing import Optional + import pytest import apache_beam as beam +from apache_beam.ml.rag.chunking.base import ChunkIdFn +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider, ChunkIdFn -from apache_beam.ml.rag.types import Chunk, Content -from typing import Optional, Dict, Any class WordSplitter(beam.DoFn): diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain.py b/sdks/python/apache_beam/ml/rag/chunking/langchain.py index bbe7af7edd0..8a35e7afdc6 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/langchain.py +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain.py @@ -15,11 +15,17 @@ # limitations under the License. # +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + import apache_beam as beam +from apache_beam.ml.rag.chunking.base import ChunkIdFn +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content from langchain.text_splitter import TextSplitter -from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider, ChunkIdFn -from apache_beam.ml.rag.types import Chunk, Content -from typing import List, Optional, Dict, Any class LangChainChunkingProvider(ChunkingTransformProvider): diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py index 45c96aee0db..6899696b9a0 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py @@ -19,18 +19,16 @@ import unittest import apache_beam as beam - +from apache_beam.ml.rag.types import Chunk from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.ml.rag.types import Chunk try: from apache_beam.ml.rag.chunking.langchain import LangChainChunkingProvider + from langchain.text_splitter import ( - RecursiveCharacterTextSplitter, - CharacterTextSplitter, - ) + CharacterTextSplitter, RecursiveCharacterTextSplitter) LANGCHAIN_AVAILABLE = True except ImportError: LANGCHAIN_AVAILABLE = False diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base.py b/sdks/python/apache_beam/ml/rag/embeddings/base.py index 51bad706c93..e05179458de 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/base.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/base.py @@ -14,10 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from apache_beam.ml.transforms.base import EmbeddingTypeAdapter -from apache_beam.ml.rag.types import Embedding, Chunk -from typing import List from collections.abc import Sequence +from typing import List + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.transforms.base import EmbeddingTypeAdapter def create_rag_adapter() -> EmbeddingTypeAdapter[Chunk, Embedding]: diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py index bc49420dfb5..4bfff37c92b 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py @@ -15,8 +15,11 @@ # limitations under the License. import unittest -from apache_beam.ml.rag.types import Chunk, Content, Embedding -from apache_beam.ml.rag.embeddings.base import (create_rag_adapter) + +from apache_beam.ml.rag.embeddings.base import create_rag_adapter +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding class RAGBaseEmbeddingsTest(unittest.TestCase): diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py index fbcfae71c44..87d7026691a 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py @@ -21,11 +21,12 @@ import apache_beam as beam from apache_beam.ml.inference.base import RunInference from apache_beam.ml.rag.embeddings.base import create_rag_adapter -from apache_beam.ml.rag.types import Chunk, Embedding -from apache_beam.ml.transforms.base import ( - EmbeddingsManager, _TextEmbeddingHandler) -from apache_beam.ml.transforms.embeddings.huggingface import ( - SentenceTransformer, _SentenceTransformerModelHandler) +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformer +from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler class HuggingfaceTextEmbeddings(EmbeddingsManager): diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py index 5076088cb96..b701830c966 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py @@ -16,16 +16,20 @@ """Tests for apache_beam.ml.rag.embeddings.huggingface.""" -import pytest import tempfile import unittest +import pytest + import apache_beam as beam from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings -from apache_beam.ml.rag.types import Chunk, Content, Embedding +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding from apache_beam.ml.transforms.base import MLTransform from apache_beam.testing.test_pipeline import TestPipeline -from apache_beam.testing.util import assert_that, equal_to +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to # pylint: disable=unused-import try: diff --git a/sdks/python/apache_beam/ml/rag/types.py b/sdks/python/apache_beam/ml/rag/types.py index 72ce5c83626..405eaf1908b 100644 --- a/sdks/python/apache_beam/ml/rag/types.py +++ b/sdks/python/apache_beam/ml/rag/types.py @@ -21,9 +21,14 @@ contracts between different stages of the pipeline. """ -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Any import uuid +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple @dataclass diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index 1f716b0c232..2107bf96370 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -15,6 +15,7 @@ # limitations under the License. import abc +import functools import logging import os import tempfile @@ -22,8 +23,8 @@ from collections.abc import Callable from collections.abc import Mapping from collections.abc import Sequence +from dataclasses import dataclass from typing import Any -from typing import cast from typing import Dict from typing import Generic from typing import Iterable @@ -31,13 +32,11 @@ from typing import Optional from typing import TypeVar from typing import Union +from typing import cast -import functools import jsonpickle import numpy as np -from dataclasses import dataclass - import apache_beam as beam from apache_beam.io.filesystems import FileSystems from apache_beam.metrics.metric import Metrics @@ -258,11 +257,6 @@ def __init__( max_batch_size: Optional[int] = None, large_model: bool = False, **kwargs): - if columns is not None and type_adapter is not None: - raise ValueError( - "Cannot specify both 'columns' and 'type_adapter'. " - "Use either columns for dict processing or type_adapter " - "for custom types.") self.load_model_args = load_model_args or {} self.min_batch_size = min_batch_size self.max_batch_size = max_batch_size