Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Update artifacts fetcher to download artifacts locally using FileSystems #30202

Merged
merged 10 commits
Feb 9, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def test_process_large_movie_review_dataset(self):

artifacts_fetcher = ArtifactsFetcher(artifact_location=artifact_location)

actual_vocab_list = artifacts_fetcher.get_vocab_list()
vocab_filename = f'vocab_{vocab_tfidf_processing.REVIEW_COLUMN}'
actual_vocab_list = artifacts_fetcher.get_vocab_list(
vocab_filename=vocab_filename)

expected_artifact_filepath = 'gs://apache-beam-ml/testing/expected_outputs/compute_and_apply_vocab' # pylint: disable=line-too-long

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,12 @@ def preprocess_data(
top_k=VOCAB_SIZE,
frequency_threshold=10,
columns=[REVIEW_COLUMN],
vocab_filename='vocab',
split_string_by_delimiter=DELIMITERS)).with_transform(
TFIDF(columns=[REVIEW_COLUMN], vocab_size=VOCAB_SIZE))
TFIDF(
columns=[REVIEW_COLUMN],
vocab_size=VOCAB_SIZE,
))
data_pcoll = data_pcoll | 'MLTransform' >> ml_transform

data_pcoll = (
Expand Down
43 changes: 34 additions & 9 deletions sdks/python/apache_beam/ml/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,47 @@
__all__ = ['ArtifactsFetcher']

import os
import tempfile
import typing

from google.cloud.storage import Client
from google.cloud.storage import transfer_manager

import tensorflow_transform as tft
from apache_beam.ml.transforms import base


class ArtifactsFetcher():
def download_artifacts_from_gcs(bucket_name, prefix, local_path):
  """Downloads artifacts from GCS to the local file system.

  Args:
    bucket_name: The name of the GCS bucket to download from.
    prefix: Prefix of GCS objects to download.
    local_path: The local path to download the folder to.

  Raises:
    RuntimeError: If any of the matched objects fail to download.
  """
  client = Client()
  bucket = client.get_bucket(bucket_name)
  blob_names = [blob.name for blob in bucket.list_blobs(prefix=prefix)]
  # download_many_to_path does not raise on per-blob failure by default: it
  # returns a list with None for each success and the raised exception object
  # for each failure. Discarding it would silently accept partial downloads,
  # so surface any failures explicitly.
  results = transfer_manager.download_many_to_path(
      bucket, blob_names, destination_directory=local_path)
  failures = [(name, result) for name, result in zip(blob_names, results)
              if isinstance(result, Exception)]
  if failures:
    raise RuntimeError(
        'Failed to download %d object(s) from gs://%s/%s: %s' %
        (len(failures), bucket_name, prefix, failures))


class ArtifactsFetcher:
"""
Utility class used to fetch artifacts from the artifact_location passed
to the TFTProcessHandlers in MLTransform.

This is intended to be used for testing purposes only.
"""
def __init__(self, artifact_location):
def __init__(self, artifact_location: str):
tempdir = tempfile.mkdtemp()
if artifact_location.startswith('gs://'):
parts = artifact_location[5:].split('/')
bucket_name = parts[0]
prefix = '/'.join(parts[1:])
download_artifacts_from_gcs(bucket_name, prefix, tempdir)

assert os.listdir(tempdir), f"No files found in {artifact_location}"
artifact_location = os.path.join(tempdir, prefix)
files = os.listdir(artifact_location)
files.remove(base._ATTRIBUTE_FILE_NAME)
# TODO: https://github.com/apache/beam/issues/29356
Expand All @@ -43,9 +72,7 @@ def __init__(self, artifact_location):
self._artifact_location = os.path.join(artifact_location, files[0])
self.transform_output = tft.TFTransformOutput(self._artifact_location)

def get_vocab_list(
self,
vocab_filename: str = 'compute_and_apply_vocab') -> typing.List[bytes]:
def get_vocab_list(self, vocab_filename: str) -> typing.List[bytes]:
"""
Returns list of vocabulary terms created during MLTransform.
"""
Expand All @@ -57,13 +84,11 @@ def get_vocab_list(
vocab_filename)) from e
return [x.decode('utf-8') for x in vocab_list]

def get_vocab_filepath(
self, vocab_filename: str = 'compute_and_apply_vocab') -> str:
def get_vocab_filepath(self, vocab_filename: str) -> str:
"""
Return the path to the vocabulary file created during MLTransform.
"""
return self.transform_output.vocabulary_file_by_name(vocab_filename)

def get_vocab_size(
self, vocab_filename: str = 'compute_and_apply_vocab') -> int:
def get_vocab_size(self, vocab_filename: str) -> int:
return self.transform_output.vocabulary_size_by_name(vocab_filename)
Loading