Add support for parallelizing processing parquet files across workers and nodes. (#19400)
tchaton authored Feb 5, 2024
1 parent 2778692 commit 7dfc279
Showing 8 changed files with 271 additions and 11 deletions.
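
The headline change: `map` and `optimize` gain an optional `reader` argument. A `BaseReader` implementation (here the new `ParquetReader`) splits the user inputs into per-worker work items and lazily reads each item inside the worker, so large parquet files can be processed in row slices across workers and nodes. A minimal usage sketch, assuming the usual `from lightning.data import map` export and the `fn(item, output_dir)` callback convention; the file paths, output directory, and callback body are hypothetical:

import os
import uuid

from lightning.data import map
from lightning.data.processing.readers import ParquetReader

def process(df, output_dir):
    # With ParquetReader(to_pandas=True), `df` is a pandas DataFrame holding one slice of rows.
    df.to_csv(os.path.join(output_dir, f"slice-{uuid.uuid4().hex}.csv"), index=False)

map(
    fn=process,
    inputs=["data/part-0.parquet", "data/part-1.parquet"],  # hypothetical parquet files
    output_dir="processed/",
    num_workers=4,
    reader=ParquetReader(num_rows=2048, to_pandas=True),  # each worker receives slices of at most 2048 rows
)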
2 changes: 2 additions & 0 deletions requirements/data/test.txt
@@ -5,3 +5,5 @@ pytest-timeout ==2.1.0
pytest-rerunfailures ==12.0
pytest-random-order ==1.1.0
viztracer
pyarrow
polars
Empty file.
47 changes: 47 additions & 0 deletions src/lightning/data/processing/image.py
@@ -0,0 +1,47 @@
import io
from typing import Optional, Tuple

from lightning_utilities.core.imports import RequirementCache

_HTTPX_AVAILABLE = RequirementCache("httpx")

# Credit to the https://github.com/rom1504/img2dataset GitHub repo.
# The code was taken from there. It has an MIT License.

def _download_image(
url: str,
timeout: int = 10,
user_agent_token: str = "pytorch-lightning",
) -> Tuple[Optional[io.BytesIO], Optional[Exception]]:
"""Download an image with urllib."""
url
img_stream = None
user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
if user_agent_token:
        user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/pytorch-lightning)"
import httpx

try:
with httpx.Client(http2=True) as client:
r = client.get(url, headers={"User-Agent": user_agent_string}, timeout=timeout)
img_stream = io.BytesIO(r.read())
return img_stream, None
except Exception as err: # pylint: disable=broad-except
if img_stream is not None:
img_stream.close()
return None, err


def download_image(
url: str,
retries: int = 0,
timeout: int = 10,
user_agent_token: str = "pytorch-lightning",
) -> Tuple[Optional[io.BytesIO], Optional[Exception]]:
if not _HTTPX_AVAILABLE:
raise ModuleNotFoundError("Please, run: `pip install httpx`.")
for _ in range(retries + 1):
img_stream, err = _download_image(url, timeout, user_agent_token)
if img_stream is not None:
return img_stream, err
return None, err
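
A hedged usage sketch of the helper above (the URL is hypothetical, and decoding the stream with Pillow is an assumption, not part of this module):

from PIL import Image  # assumption: Pillow is installed to decode the bytes

from lightning.data.processing.image import download_image

img_stream, err = download_image("https://example.com/cat.jpg", retries=2, timeout=10)
if img_stream is not None:
    image = Image.open(img_stream)  # read the image from the in-memory stream
    print(image.size)
else:
    print(f"Download failed: {err}")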
131 changes: 131 additions & 0 deletions src/lightning/data/processing/readers.py
@@ -0,0 +1,131 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional

from lightning_utilities.core.imports import RequirementCache

from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks
from lightning.data.utilities.env import _DistributedEnv

_POLARS_AVAILABLE = RequirementCache("polars")
_PYARROW_AVAILABLE = RequirementCache("pyarrow")


class BaseReader(ABC):

def get_num_nodes(self) -> int:
return int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1))

def get_node_rank(self) -> int:
return int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0))

@abstractmethod
def items_to_workers(self, items: List[Any], num_workers: int) -> List[List[Any]]:
"""This method is meant to convert the items provided by the users into items to be processed by the
workers."""
pass

@abstractmethod
def read(self, item: Any) -> Any:
"""Read the data associated to an item."""
pass


@dataclass
class ParquetSlice:
"""Keep track of a parquet file slice with its filepath, start and end."""
filepath: str
start: int
end: int


class ParquetReader(BaseReader):

def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> None:
self.num_rows = num_rows
self.to_pandas = to_pandas

if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE:
raise ModuleNotFoundError("Please, run: `pip install pyarrow polars`")

def _get_num_rows(self, path: str) -> int:
if _PYARROW_AVAILABLE:
import pyarrow.dataset as ds
df = ds.dataset(path).scanner()
return df.count_rows()

        # NOTE: There is a bug in polars that causes read_parquet to hang, so pyarrow is preferred when available.
if _POLARS_AVAILABLE:
import polars as pol
df = pol.scan_parquet(path)
num_rows = df.select(pol.len()).collect().item()
return num_rows

raise RuntimeError("Please, install either pyarrow or polars.")

def read(self, item: ParquetSlice) -> Any:
if _POLARS_AVAILABLE:
import polars as pol
            # polars' `slice` takes (offset, length), so convert the absolute end index into a length.
            df = pol.scan_parquet(item.filepath).slice(item.start, item.end - item.start).collect()

if self.to_pandas:
df = df.to_pandas()

return df

if _PYARROW_AVAILABLE:
import pyarrow.dataset as ds

df = ds.dataset(item.filepath).scanner()

            # Take the full row range [start, end), not just the two boundary rows.
            df = df.take(list(range(item.start, item.end)))

            if self.to_pandas:
                df = df.to_pandas()

return df

raise RuntimeError("Please, install either pyarrow or polars.")


def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]:
intervals = [(0, self._get_num_rows(item)) for item in items]

world_size = self.get_num_nodes() * num_workers
node_rank = self.get_node_rank()

fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes())
parquet_indexes_per_worker, p_slices_per_worker = _associate_chunks_and_internals_to_ranks(
fake_distributed_env, list(range(len(items))), intervals, False)

workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(num_workers)]

iterator = enumerate(zip(parquet_indexes_per_worker, p_slices_per_worker))

node_start = node_rank * num_workers
node_end = (node_rank + 1) * num_workers

for worker_idx, (parquet_indexes, p_slices) in iterator:
if node_start <= worker_idx < node_end:
if self.num_rows:
workers_user_items[worker_idx % num_workers].extend([
ParquetSlice(
items[parquet_index], p_slice_start, p_slice_start + self.num_rows
if p_slice[1] > (p_slice_start + self.num_rows) else
p_slice[1]
)
for parquet_index, p_slice in zip(parquet_indexes, p_slices)
for p_slice_start in range(p_slice[0], p_slice[1] + self.num_rows, self.num_rows)
if p_slice_start < p_slice[1]
])
else:
workers_user_items[worker_idx % num_workers].extend([
ParquetSlice(items[parquet_index], *p_slice)
for parquet_index, p_slice in zip(parquet_indexes, p_slices)
])

assert len(workers_user_items) == num_workers
assert all(len(w) for w in workers_user_items)

return workers_user_items
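
For intuition, `items_to_workers` counts the rows of each parquet file, distributes the row intervals across `num_nodes * num_workers` ranks via `_associate_chunks_and_internals_to_ranks`, and then chops each assigned interval into `ParquetSlice` objects of at most `num_rows` rows. A simplified, self-contained sketch of that final chopping step (an illustration, not the library code):

from dataclasses import dataclass
from typing import List

@dataclass
class Slice:
    filepath: str
    start: int
    end: int

def chop_interval(filepath: str, start: int, end: int, num_rows: int) -> List[Slice]:
    """Split the row interval [start, end) into slices of at most `num_rows` rows."""
    return [Slice(filepath, s, min(s + num_rows, end)) for s in range(start, end, num_rows)]

# A 5000-row interval with num_rows=2048 yields (0, 2048), (2048, 4096), (4096, 5000).
print(chop_interval("part-0.parquet", 0, 5000, 2048))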
38 changes: 27 additions & 11 deletions src/lightning/data/streaming/data_processor.py
@@ -20,6 +20,7 @@
from tqdm.auto import tqdm as _tqdm

from lightning import seed_everything
from lightning.data.processing.readers import BaseReader
from lightning.data.streaming import Cache
from lightning.data.streaming.cache import Dir
from lightning.data.streaming.client import S3Client
@@ -158,8 +159,9 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)

elif os.path.isfile(path):
os.makedirs(os.path.dirname(local_path), exist_ok=True)
shutil.copyfile(path, local_path)
if not path.startswith("/teamspace/studios/this_studio"):
os.makedirs(os.path.dirname(local_path), exist_ok=True)
shutil.copyfile(path, local_path)
else:
raise ValueError(f"The provided {input_dir.url} isn't supported.")

@@ -340,6 +342,7 @@ def __init__(
num_downloaders: int,
num_uploaders: int,
remove: bool,
reader: Optional[BaseReader] = None,
) -> None:
"""The BaseWorker is responsible to process the user data."""
self.worker_index = worker_index
@@ -353,6 +356,7 @@ def __init__(
self.num_downloaders = num_downloaders
self.num_uploaders = num_uploaders
self.remove = remove
self.reader = reader
self.paths: List[List[str]] = []
self.remover: Optional[Process] = None
self.downloaders: List[Process] = []
@@ -433,7 +437,7 @@ def _loop(self) -> None:
self.progress_queue.put((self.worker_index, self._counter))
self._last_time = time()

if self.remove and self.input_dir.path is not None:
if self.remove and self.input_dir.path is not None and self.reader is None:
self.remove_queue.put(self.paths[index])

try:
@@ -476,7 +480,7 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
self.to_upload_queues[self._counter % self.num_uploaders].put(data)

def _collect_paths(self) -> None:
if self.input_dir.path is None:
if self.input_dir.path is None or self.reader is not None:
for index in range(len(self.items)):
self.ready_to_process_queue.put(index)
for _ in range(self.num_downloaders):
@@ -513,7 +517,7 @@ def is_path(element: Any) -> bool:
paths = []
for index, path in indexed_paths.items():
paths.append(path)
if self.input_dir:
if self.input_dir and not self.input_dir.path.startswith("/teamspace/studios/this_studio"):
path = path.replace(self.input_dir.path, self.cache_data_dir)
flattened_item[index] = path

@@ -525,8 +529,9 @@ def is_path(element: Any) -> bool:
self.items = items

def _start_downloaders(self) -> None:
if self.input_dir.path is None:
if self.input_dir.path is None or self.reader is not None:
return

for _ in range(self.num_downloaders):
to_download_queue: Queue = Queue()
p = Process(
@@ -583,7 +588,7 @@ def _start_uploaders(self) -> None:

def _handle_data_chunk_recipe(self, index: int) -> None:
try:
self._current_item = self.items[index]
self._current_item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
item_data_or_generator = self.data_recipe.prepare_item(self._current_item)
if isinstance(item_data_or_generator, types.GeneratorType):
for item_data in item_data_or_generator:
@@ -596,7 +601,7 @@ def _handle_data_chunk_recipe(self, index: int) -> None:
self._try_upload(chunk_filepath)
self._index_counter += 1
except Exception as e:
raise RuntimeError(f"Failed processing {self._current_item}") from e
raise RuntimeError(f"Failed processing {self.items[index]}") from e

def _handle_data_chunk_recipe_end(self) -> None:
chunks_filepaths = self.cache.done()
@@ -609,7 +614,8 @@ def _handle_data_chunk_recipe_end(self) -> None:
def _handle_data_transform_recipe(self, index: int) -> None:
# Don't use a context manager to avoid deleting files that are being uploaded.
output_dir = tempfile.mkdtemp()
item_data = self.data_recipe.prepare_item(self.items[index], str(output_dir), len(self.items) - 1 == index)
item = self.items[index] if self.reader is None else self.reader.read(self.items[index])
item_data = self.data_recipe.prepare_item(item, str(output_dir), len(self.items) - 1 == index)
if item_data is not None:
raise ValueError(
"When using a `DataTransformRecipe`, the `prepare_item` shouldn't return anything."
@@ -792,6 +798,7 @@ def __init__(
random_seed: Optional[int] = 42,
reorder_files: bool = True,
weights: Optional[List[int]] = None,
reader: Optional[BaseReader] = None,
):
"""The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
training faster.
@@ -809,6 +816,7 @@ def __init__(
Set this to ``False`` if the order in which samples are processed should be preserved.
weights: Provide a list of weights associated to the inputs.
This is used to evenly split the work among the workers.
            reader: Maps the inputs to worker inputs and provides a read method to load a slice of the data.
"""
self.input_dir = _resolve_dir(input_dir)
@@ -825,6 +833,10 @@ def __init__(
self.stop_queues: List[Queue] = []
self.reorder_files = reorder_files
self.weights = weights
self.reader = reader

if self.reader is not None and self.weights is not None:
raise ValueError("Either the reader or the weights needs to be defined.")

# Ensure the input dir is the same across all nodes
self.input_dir = broadcast_object("input_dir", self.input_dir)
@@ -853,7 +865,10 @@ def run(self, data_recipe: DataRecipe) -> None:
if not isinstance(user_items, list):
raise ValueError("The `prepare_structure` should return a list of item metadata.")

if self.weights is not None:
if self.reader:
workers_user_items = self.reader.items_to_workers(user_items, self.num_workers)

elif self.weights is not None:
if len(self.weights) != len(user_items):
raise ValueError("The provided weights length should match the inputs' length.")
workers_user_items = _map_items_to_workers_weighted(
Expand All @@ -880,7 +895,7 @@ def run(self, data_recipe: DataRecipe) -> None:

self._cleanup_cache()

print(f"Starting {self.num_workers} workers")
print(f"Starting {self.num_workers} workers with {num_items} items.")

if self.input_dir is None and self.src_resolver is not None and self.input_dir:
self.input_dir = self.src_resolver(self.input_dir)
@@ -988,6 +1003,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
self.num_downloaders,
self.num_uploaders,
self.delete_cached_files,
self.reader,
)
worker.start()
workers.append(worker)
5 changes: 5 additions & 0 deletions src/lightning/data/streaming/functions.py
@@ -22,6 +22,7 @@

import torch

from lightning.data.processing.readers import BaseReader
from lightning.data.streaming.constants import _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.streaming.resolver import (
@@ -157,6 +158,7 @@ def map(
num_downloaders: Optional[int] = None,
reorder_files: bool = True,
error_when_not_empty: bool = False,
reader: Optional[BaseReader] = None,
) -> None:
"""This function map a callbable over a collection of files possibly in a distributed way.
@@ -203,6 +205,7 @@ def map(
num_downloaders=num_downloaders,
reorder_files=reorder_files,
weights=weights,
reader=reader,
)
return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
return _execute(
@@ -225,6 +228,7 @@ def optimize(
machine: Optional[str] = None,
num_downloaders: Optional[int] = None,
reorder_files: bool = True,
reader: Optional[BaseReader] = None,
) -> None:
"""This function converts a dataset into chunks possibly in a distributed way.
@@ -274,6 +278,7 @@ def optimize(
fast_dev_run=fast_dev_run,
num_downloaders=num_downloaders,
reorder_files=reorder_files,
reader=reader,
)
return data_processor.run(
LambdaDataChunkRecipe(
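
Putting the new `reader` argument of `optimize` to use: a hedged sketch that converts parquet rows into a chunked, streaming-ready dataset. The bucket path, chunk size, and row handling are hypothetical, and the yielded row values are assumed to be types the cache can serialize:

from lightning.data import optimize
from lightning.data.processing.readers import ParquetReader

def to_rows(df):
    # `df` is the pandas DataFrame produced for one ParquetSlice; yield one dict per row.
    for row in df.itertuples(index=False):
        yield row._asdict()

optimize(
    fn=to_rows,
    inputs=["s3://my-bucket/part-0.parquet"],  # hypothetical input
    output_dir="optimized/",
    chunk_bytes="64MB",
    num_workers=2,
    reader=ParquetReader(num_rows=1024),
)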
Empty file.
