Add support for parallelizing processing parquet files across workers and nodes. #19400

Merged
merged 17 commits
Feb 5, 2024
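In brief: the diff below adds a `BaseReader` abstraction plus a `ParquetReader` that splits parquet files into row slices spread across `num_nodes * num_workers` buckets, and threads a new `reader` argument through `map`, `optimize`, and the `DataProcessor` workers. A minimal usage sketch of the new argument (the `process` signature, the import locations, and the file paths are illustrative assumptions, not taken verbatim from this diff):

import os
from uuid import uuid4

from lightning.data import map  # assumption: re-exported here; otherwise import from lightning.data.streaming.functions
from lightning.data.processing.readers import ParquetReader

def process(df, output_dir):
    # With a reader configured, each call receives one parquet slice
    # (a pandas DataFrame when `to_pandas=True`) instead of a file path.
    df.to_parquet(os.path.join(output_dir, f"{uuid4()}.parquet"))

if __name__ == "__main__":
    map(
        process,
        inputs=["data/part-0.parquet", "data/part-1.parquet"],  # illustrative paths
        output_dir="processed",
        num_workers=4,
        reader=ParquetReader(num_rows=2048, to_pandas=True),
    )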
2 changes: 2 additions & 0 deletions requirements/data/test.txt
@@ -5,3 +5,5 @@ pytest-timeout ==2.1.0
pytest-rerunfailures ==12.0
pytest-random-order ==1.1.0
viztracer
pyarrow
polars
Empty file.
47 changes: 47 additions & 0 deletions src/lightning/data/processing/image.py
@@ -0,0 +1,47 @@
import io
from typing import Optional, Tuple

from lightning_utilities.core.imports import RequirementCache

_HTTPX_AVAILABLE = RequirementCache("httpx")

# Credit to the https://github.com/rom1504/img2dataset GitHub repo.
# The code was taken from there. It has an MIT License.


def _download_image(
    url: str,
    timeout: int = 10,
    user_agent_token: str = "pytorch-lightning",
) -> Tuple[Optional[io.BytesIO], Optional[Exception]]:
    """Download an image with httpx."""
    img_stream = None
    user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
    if user_agent_token:
        user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/pytorch-lightning)"

    import httpx

    try:
        with httpx.Client(http2=True) as client:
            r = client.get(url, headers={"User-Agent": user_agent_string}, timeout=timeout)
            img_stream = io.BytesIO(r.read())
            return img_stream, None
    except Exception as err:  # pylint: disable=broad-except
        if img_stream is not None:
            img_stream.close()
        return None, err


def download_image(
    url: str,
    retries: int = 0,
    timeout: int = 10,
    user_agent_token: str = "pytorch-lightning",
) -> Tuple[Optional[io.BytesIO], Optional[Exception]]:
    if not _HTTPX_AVAILABLE:
        raise ModuleNotFoundError("Please run: `pip install httpx`.")
    for _ in range(retries + 1):
        img_stream, err = _download_image(url, timeout, user_agent_token)
        if img_stream is not None:
            return img_stream, err
    return None, err
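A small usage sketch for the helper above (requires `httpx`; the URL is a placeholder):

from lightning.data.processing.image import download_image

img_stream, err = download_image("https://example.com/image.jpg", retries=2, timeout=10)
if err is not None:
    print(f"Download failed after retries: {err}")
else:
    data = img_stream.read()  # raw bytes of the response body, e.g. to feed into PIL.Image.open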
120 changes: 120 additions & 0 deletions src/lightning/data/processing/readers.py
@@ -0,0 +1,120 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional

from lightning_utilities.core.imports import RequirementCache

from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks
from lightning.data.utilities.env import _DistributedEnv

_POLARS_AVAILABLE = RequirementCache("polars")
_PYARROW_AVAILABLE = RequirementCache("pyarrow")


class BaseReader(ABC):

    def get_num_nodes(self) -> int:
        return int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1))

    @abstractmethod
    def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[Any]]:
        """Convert the items provided by the user into the per-worker lists of items to be processed."""
        pass

    @abstractmethod
    def read(self, item: Any) -> Any:
        """Read the data associated with an item."""
        pass


@dataclass
class ParquetSlice:
    """Keep track of a parquet file slice with its filepath, start and end."""

    filepath: str
    start: int
    end: int


class ParquetReader(BaseReader):

    def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> None:
        self.num_rows = num_rows
        self.to_pandas = to_pandas

        if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE:
            raise ModuleNotFoundError("Please run: `pip install pyarrow polars`")

    def _get_num_rows(self, path: str) -> int:
        if _PYARROW_AVAILABLE:
            import pyarrow.dataset as ds

            df = ds.dataset(path).scanner()
            return df.count_rows()

        # FIXED: A bug in polars makes `read_parquet` hang, so count rows lazily with `scan_parquet`.
        if _POLARS_AVAILABLE:
            import polars as pol

            df = pol.scan_parquet(path)
            num_rows = df.select(pol.len()).collect().item()
            return num_rows

        raise RuntimeError("Please install either pyarrow or polars.")

    def read(self, item: ParquetSlice) -> Any:
        if _POLARS_AVAILABLE:
            import polars as pol

            # polars `slice` takes (offset, length).
            df = pol.scan_parquet(item.filepath).slice(item.start, item.end - item.start).collect()

            if self.to_pandas:
                df = df.to_pandas()

            return df

        if _PYARROW_AVAILABLE:
            import pyarrow.dataset as ds

            scanner = ds.dataset(item.filepath).scanner()

            # `take` expects row indices, so materialize every index in the [start, end) range.
            df = scanner.take(list(range(item.start, item.end)))

            if self.to_pandas:
                df = df.to_pandas()

            return df

        raise RuntimeError("Please install either pyarrow or polars.")

    def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]:
        intervals = [(0, self._get_num_rows(item)) for item in items]

        world_size = self.get_num_nodes() * num_workers

        fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes())
        parquet_indexes_per_worker, parquet_slices_per_worker = _associate_chunks_and_internals_to_ranks(
            fake_distributed_env, list(range(len(items))), intervals, False
        )

        workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(world_size)]

        iterator = enumerate(zip(parquet_indexes_per_worker, parquet_slices_per_worker))

        for worker_idx, (parquet_indexes, parquet_slices) in iterator:
            if self.num_rows:
                # Cut each assigned row range into slices of at most `num_rows` rows.
                workers_user_items[worker_idx].extend([
                    ParquetSlice(
                        items[parquet_index],
                        parquet_slice_start,
                        min(parquet_slice_start + self.num_rows, parquet_slice[1]),
                    )
                    for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices)
                    for parquet_slice_start in range(parquet_slice[0], parquet_slice[1] + self.num_rows, self.num_rows)
                    if parquet_slice_start < parquet_slice[1]
                ])
            else:
                workers_user_items[worker_idx].extend([
                    ParquetSlice(items[parquet_index], *parquet_slice)
                    for parquet_index, parquet_slice in zip(parquet_indexes, parquet_slices)
                ])

        return workers_user_items
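For intuition, a sketch of how the reader above is driven (the file names are placeholders for real parquet files on disk, since `items_to_workers` opens them to count rows):

from lightning.data.processing.readers import ParquetReader

reader = ParquetReader(num_rows=1000, to_pandas=True)

# Each file's row range is first partitioned across num_nodes * num_workers buckets,
# then each bucket's range is cut into ParquetSlice objects of at most `num_rows` rows.
per_worker = reader.items_to_workers(["part-0.parquet", "part-1.parquet"], num_workers=4)

first_slice = per_worker[0][0]   # ParquetSlice(filepath=..., start=..., end=...)
df = reader.read(first_slice)    # a pandas DataFrame because to_pandas=True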
38 changes: 27 additions & 11 deletions src/lightning/data/streaming/data_processor.py
@@ -20,6 +20,7 @@
from tqdm.auto import tqdm as _tqdm

from lightning import seed_everything
from lightning.data.processing.readers import BaseReader
from lightning.data.streaming import Cache
from lightning.data.streaming.cache import Dir
from lightning.data.streaming.client import S3Client
@@ -158,8 +159,9 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)

elif os.path.isfile(path):
os.makedirs(os.path.dirname(local_path), exist_ok=True)
shutil.copyfile(path, local_path)
if not path.startswith("/teamspace/studios/this_studio"):
os.makedirs(os.path.dirname(local_path), exist_ok=True)
shutil.copyfile(path, local_path)
else:
raise ValueError(f"The provided {input_dir.url} isn't supported.")

@@ -340,6 +342,7 @@ def __init__(
num_downloaders: int,
num_uploaders: int,
remove: bool,
reader: Optional[BaseReader] = None,
) -> None:
"""The BaseWorker is responsible to process the user data."""
self.worker_index = worker_index
@@ -353,6 +356,7 @@ def __init__(
self.num_downloaders = num_downloaders
self.num_uploaders = num_uploaders
self.remove = remove
self.reader = reader
self.paths: List[List[str]] = []
self.remover: Optional[Process] = None
self.downloaders: List[Process] = []
@@ -433,7 +437,7 @@ def _loop(self) -> None:
self.progress_queue.put((self.worker_index, self._counter))
self._last_time = time()

if self.remove and self.input_dir.path is not None:
if self.remove and self.input_dir.path is not None and self.reader is None:
self.remove_queue.put(self.paths[index])

try:
@@ -476,7 +480,7 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
self.to_upload_queues[self._counter % self.num_uploaders].put(data)

def _collect_paths(self) -> None:
if self.input_dir.path is None:
if self.input_dir.path is None or self.reader is not None:
for index in range(len(self.items)):
self.ready_to_process_queue.put(index)
for _ in range(self.num_downloaders):
@@ -513,7 +517,7 @@ def is_path(element: Any) -> bool:
paths = []
for index, path in indexed_paths.items():
paths.append(path)
if self.input_dir:
if self.input_dir and not self.input_dir.path.startswith("/teamspace/studios/this_studio"):
path = path.replace(self.input_dir.path, self.cache_data_dir)
flattened_item[index] = path

@@ -525,8 +529,9 @@ def is_path(element: Any) -> bool:
self.items = items

def _start_downloaders(self) -> None:
if self.input_dir.path is None:
if self.input_dir.path is None or self.reader is not None:
return

for _ in range(self.num_downloaders):
to_download_queue: Queue = Queue()
p = Process(
@@ -583,7 +588,7 @@ def _start_uploaders(self) -> None:

def _handle_data_chunk_recipe(self, index: int) -> None:
try:
self._current_item = self.items[index]
self._current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
item_data_or_generator = self.data_recipe.prepare_item(self._current_item)
if isinstance(item_data_or_generator, types.GeneratorType):
for item_data in item_data_or_generator:
Expand All @@ -596,7 +601,7 @@ def _handle_data_chunk_recipe(self, index: int) -> None:
self._try_upload(chunk_filepath)
self._index_counter += 1
except Exception as e:
raise RuntimeError(f"Failed processing {self._current_item}") from e
raise RuntimeError(f"Failed processing {self.items[index]}") from e

def _handle_data_chunk_recipe_end(self) -> None:
chunks_filepaths = self.cache.done()
@@ -609,7 +614,8 @@ def _handle_data_chunk_recipe_end(self) -> None:
def _handle_data_transform_recipe(self, index: int) -> None:
# Don't use a context manager to avoid deleting files that are being uploaded.
output_dir = tempfile.mkdtemp()
item_data = self.data_recipe.prepare_item(self.items[index], str(output_dir), len(self.items) - 1 == index)
item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
item_data = self.data_recipe.prepare_item(item, str(output_dir), len(self.items) - 1 == index)
if item_data is not None:
raise ValueError(
"When using a `DataTransformRecipe`, the `prepare_item` shouldn't return anything."
@@ -792,6 +798,7 @@ def __init__(
random_seed: Optional[int] = 42,
reorder_files: bool = True,
weights: Optional[List[int]] = None,
reader: Optional[BaseReader] = None,
):
"""The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
training faster.
@@ -809,6 +816,7 @@ def __init__(
Set this to ``False`` if the order in which samples are processed should be preserved.
weights: Provide a list of weights associated to the inputs.
This is used to evenly split the work among the workers.
reader: Maps the inputs to worker inputs and provides a `read` method to load a slice of the data.

"""
self.input_dir = _resolve_dir(input_dir)
@@ -825,6 +833,10 @@ def __init__(
self.stop_queues: List[Queue] = []
self.reorder_files = reorder_files
self.weights = weights
self.reader = reader

if self.reader is not None and self.weights is not None:
raise ValueError("Either the reader or the weights needs to be defined.")

# Ensure the input dir is the same across all nodes
self.input_dir = broadcast_object("input_dir", self.input_dir)
@@ -853,7 +865,10 @@ def run(self, data_recipe: DataRecipe) -> None:
if not isinstance(user_items, list):
raise ValueError("The `prepare_structure` should return a list of item metadata.")

if self.weights is not None:
if self.reader:
workers_user_items = self.reader.items_to_workers(user_items, self.num_workers)

elif self.weights is not None:
if len(self.weights) != len(user_items):
raise ValueError("The provided weights length should match the inputs' length.")
workers_user_items = _map_items_to_workers_weighted(
@@ -880,7 +895,7 @@ def run(self, data_recipe: DataRecipe) -> None:

self._cleanup_cache()

print(f"Starting {self.num_workers} workers")
print(f"Starting {self.num_workers} workers with {num_items} items.")

if self.input_dir is None and self.src_resolver is not None and self.input_dir:
self.input_dir = self.src_resolver(self.input_dir)
@@ -988,6 +1003,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
self.num_downloaders,
self.num_uploaders,
self.delete_cached_files,
self.reader,
)
worker.start()
workers.append(worker)
5 changes: 5 additions & 0 deletions src/lightning/data/streaming/functions.py
@@ -22,6 +22,7 @@

import torch

from lightning.data.processing.readers import BaseReader
from lightning.data.streaming.constants import _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.streaming.resolver import (
@@ -157,6 +158,7 @@ def map(
num_downloaders: Optional[int] = None,
reorder_files: bool = True,
error_when_not_empty: bool = False,
reader: Optional[BaseReader] = None,
) -> None:
"""This function map a callbable over a collection of files possibly in a distributed way.

@@ -203,6 +205,7 @@ def map(
num_downloaders=num_downloaders,
reorder_files=reorder_files,
weights=weights,
reader=reader,
)
return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
return _execute(
@@ -225,6 +228,7 @@ def optimize(
machine: Optional[str] = None,
num_downloaders: Optional[int] = None,
reorder_files: bool = True,
reader: Optional[BaseReader] = None,
) -> None:
"""This function converts a dataset into chunks possibly in a distributed way.

@@ -274,6 +278,7 @@ def optimize(
fast_dev_run=fast_dev_run,
num_downloaders=num_downloaders,
reorder_files=reorder_files,
reader=reader,
)
return data_processor.run(
LambdaDataChunkRecipe(
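Putting the new `reader` argument together with `optimize`, a hedged sketch (the `to_samples` contract, the `chunk_bytes` value, the import locations, and the paths are illustrative assumptions):

import os

from lightning.data import optimize  # assumption: re-exported here, mirroring `map`
from lightning.data.processing.readers import ParquetReader

def to_samples(df):
    # `df` is one parquet slice (a pandas DataFrame); yield one sample per row so the
    # generator branch of `_handle_data_chunk_recipe` serializes them into chunks.
    for row in df.itertuples(index=False):
        yield row._asdict()

if __name__ == "__main__":
    optimize(
        to_samples,
        inputs=[os.path.join("data", f) for f in sorted(os.listdir("data")) if f.endswith(".parquet")],
        output_dir="optimized",
        chunk_bytes="64MB",
        num_workers=os.cpu_count(),
        reader=ParquetReader(num_rows=2048, to_pandas=True),
    )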
Empty file.