
Add multiple uploaders to the map, optimize (#18989)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: thomas <[email protected]>
(cherry picked from commit 7288302)
tchaton authored and lantiga committed Nov 15, 2023
1 parent 8d0830b commit 69fc0b4
Showing 2 changed files with 80 additions and 47 deletions.
123 changes: 77 additions & 46 deletions src/lightning/data/streaming/data_processor.py
@@ -7,6 +7,7 @@
 import types
 from abc import abstractmethod
 from dataclasses import dataclass
+from datetime import datetime
 from multiprocessing import Process, Queue
 from queue import Empty
 from shutil import copyfile, rmtree
@@ -15,7 +16,7 @@
 from urllib import parse
 
 import torch
-from tqdm.auto import tqdm
+from tqdm.auto import tqdm as _tqdm
 
 from lightning import seed_everything
 from lightning.data.streaming import Cache
@@ -278,6 +279,7 @@ def __init__(
        error_queue: Queue,
        stop_queue: Queue,
        num_downloaders: int,
+        num_uploaders: int,
        remove: bool,
    ) -> None:
        """The BaseWorker is responsible for processing the user data."""
@@ -290,18 +292,19 @@ def __init__(
        self.items = items
        self.num_items = len(self.items)
        self.num_downloaders = num_downloaders
+        self.num_uploaders = num_uploaders
        self.remove = remove
        self.paths: List[List[str]] = []
        self.remover: Optional[Process] = None
        self.downloaders: List[Process] = []
+        self.uploaders: List[Process] = []
        self.to_download_queues: List[Queue] = []
+        self.to_upload_queues: List[Queue] = []
        self.stop_queue = stop_queue
        self.ready_to_process_queue: Queue = Queue()
        self.remove_queue: Queue = Queue()
-        self.upload_queue: Queue = Queue()
        self.progress_queue: Queue = progress_queue
        self.error_queue: Queue = error_queue
-        self.uploader: Optional[Process] = None
        self._collected_items = 0
        self._counter = 0
        self._last_time = time()
@@ -316,14 +319,14 @@ def run(self) -> None:
            traceback_format = traceback.format_exc()
            print(traceback_format)
            self.error_queue.put(traceback_format)
-        print(f"Worker {self.worker_index} is done.")
+        print(f"Worker {str(_get_node_rank() * self.num_workers + self.worker_index)} is done.")
 
    def _setup(self) -> None:
        self._set_environ_variables()
        self._create_cache()
        self._collect_paths()
        self._start_downloaders()
-        self._start_uploader()
+        self._start_uploaders()
        self._start_remover()
 
    def _loop(self) -> None:
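The new log lines identify each worker by a global rank rather than the node-local `worker_index`. The formula is ordinary node-major flattening; a one-function sketch (standalone, not part of the diff):

```python
def global_worker_rank(node_rank: int, num_workers: int, worker_index: int) -> int:
    # Node 0 owns ranks 0..num_workers-1, node 1 the next block, and so on.
    return node_rank * num_workers + worker_index
```

With 4 workers per node, worker 2 on node 1 would log as "Worker 6".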
@@ -335,13 +338,19 @@ def _loop(self) -> None:
            if index is None:
                num_downloader_finished += 1
                if num_downloader_finished == self.num_downloaders:
+                    print(f"Worker {str(_get_node_rank() * self.num_workers + self.worker_index)} is terminating.")
+
                    if isinstance(self.data_recipe, DataChunkRecipe):
                        self._handle_data_chunk_recipe_end()
 
                    if self.output_dir.url if self.output_dir.url else self.output_dir.path:
-                        assert self.uploader
-                        self.upload_queue.put(None)
-                        self.uploader.join()
+                        # Inform the uploaders they are done working
+                        for i in range(self.num_uploaders):
+                            self.to_upload_queues[i].put(None)
+
+                        # Wait for them all to be finished
+                        for uploader in self.uploaders:
+                            uploader.join()
 
                    if self.remove:
                        assert self.remover
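The shutdown handshake above is the standard sentinel pattern for `multiprocessing`: push one `None` per consumer queue, then `join()` every process so no upload is cut short. A self-contained sketch of the idea, with generic names rather than this class's attributes:

```python
from multiprocessing import Process, Queue
from typing import List

def _consumer(queue: Queue) -> None:
    # Drain the queue until the None sentinel arrives.
    while True:
        if queue.get() is None:
            return

def shutdown(processes: List[Process], queues: List[Queue]) -> None:
    for queue in queues:
        queue.put(None)  # inform each consumer it is done working
    for process in processes:
        process.join()   # wait for them all to finish
```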
@@ -402,7 +411,7 @@ def _try_upload(self, filepath: Optional[str]) -> None:
            return
 
        assert os.path.exists(filepath), filepath
-        self.upload_queue.put(filepath)
+        self.to_upload_queues[self._counter % self.num_uploaders].put(filepath)
 
    def _collect_paths(self) -> None:
        items = []
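Selecting a queue with `self._counter % self.num_uploaders` is plain round-robin sharding, so consecutive files land on different uploaders. As a standalone illustration (generic names, not the worker's attributes):

```python
from multiprocessing import Queue
from typing import List

def dispatch_round_robin(queues: List[Queue], filepaths: List[str]) -> None:
    # File i goes to queue i % len(queues), spreading uploads evenly.
    for i, filepath in enumerate(filepaths):
        queues[i % len(queues)].put(filepath)
```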
@@ -475,19 +484,24 @@ def _start_remover(self) -> None:
        )
        self.remover.start()
 
-    def _start_uploader(self) -> None:
+    def _start_uploaders(self) -> None:
        if self.output_dir.path is None and self.output_dir.url is None:
            return
-        self.uploader = Process(
-            target=_upload_fn,
-            args=(
-                self.upload_queue,
-                self.remove_queue,
-                self.cache_chunks_dir,
-                self.output_dir,
-            ),
-        )
-        self.uploader.start()
+
+        for _ in range(self.num_uploaders):
+            to_upload_queue: Queue = Queue()
+            p = Process(
+                target=_upload_fn,
+                args=(
+                    to_upload_queue,
+                    self.remove_queue,
+                    self.cache_chunks_dir,
+                    self.output_dir,
+                ),
+            )
+            p.start()
+            self.uploaders.append(p)
+            self.to_upload_queues.append(to_upload_queue)
 
    def _handle_data_chunk_recipe(self, index: int) -> None:
        try:
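`_upload_fn` itself is outside this diff; the call above only shows that each uploader now receives its own input queue plus the shared `remove_queue`, the cache directory, and the output `Dir`. A plausible skeleton under those assumptions, illustrative only and not the module's real implementation:

```python
import os
from multiprocessing import Queue
from shutil import copyfile

def upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir) -> None:
    """Illustrative consumer: copy finished chunks to a local output directory."""
    # cache_dir mirrors the argument list in the diff but is unused in this sketch.
    while True:
        filepath = upload_queue.get()
        if filepath is None:  # sentinel sent by the worker's shutdown loop
            return
        # Assumption: output_dir exposes .path/.url like this module's Dir;
        # remote .url uploads are omitted from this sketch.
        if output_dir.path:
            copyfile(filepath, os.path.join(output_dir.path, os.path.basename(filepath)))
        # Assumption: uploaded files are handed off to the remover for deletion.
        remove_queue.put(filepath)
```

Giving each uploader a private queue, instead of sharing one queue among all of them, keeps dispatch deterministic and spreads the feeder overhead across processes.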
@@ -509,10 +523,10 @@ def _handle_data_chunk_recipe(self, index: int) -> None:
    def _handle_data_chunk_recipe_end(self) -> None:
        chunks_filepaths = self.cache.done()
 
-        if chunks_filepaths:
-            for chunk_filepath in chunks_filepaths:
+        if chunks_filepaths and len(self.to_upload_queues):
+            for i, chunk_filepath in enumerate(chunks_filepaths):
                if isinstance(chunk_filepath, str) and os.path.exists(chunk_filepath):
-                    self.upload_queue.put(chunk_filepath)
+                    self.to_upload_queues[i % self.num_uploaders].put(chunk_filepath)
 
    def _handle_data_transform_recipe(self, index: int) -> None:
        # Don't use a context manager to avoid deleting files that are being uploaded.
@@ -721,6 +735,7 @@ def __init__(
        output_dir: Optional[Union[str, Dir]] = None,
        num_workers: Optional[int] = None,
        num_downloaders: Optional[int] = None,
+        num_uploaders: Optional[int] = None,
        delete_cached_files: bool = True,
        fast_dev_run: Optional[Union[bool, int]] = None,
        random_seed: Optional[int] = 42,
@@ -734,6 +749,7 @@ def __init__(
            output_dir: The path to where the output data are stored.
            num_workers: The number of worker threads to use.
            num_downloaders: The number of file downloaders to use.
+            num_uploaders: The number of file uploaders to use.
            delete_cached_files: Whether to delete the cached files.
            fast_dev_run: Whether to run a quick dev run.
            random_seed: The random seed to be set before shuffling the data.
@@ -744,7 +760,8 @@ def __init__(
        self.input_dir = _resolve_dir(input_dir)
        self.output_dir = _resolve_dir(output_dir)
        self.num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4)
-        self.num_downloaders = num_downloaders or 1
+        self.num_downloaders = num_downloaders or 2
+        self.num_uploaders = num_uploaders or 5
        self.delete_cached_files = delete_cached_files
        self.fast_dev_run = _get_fast_dev_run() if fast_dev_run is None else fast_dev_run
        self.workers: Any = []
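With the argument threaded through, callers can tune upload parallelism directly; when it is omitted, the constructor now falls back to 5 uploaders and 2 downloaders per worker. A hedged usage sketch mirroring the updated tests (the paths are hypothetical, and the surrounding class is assumed to be this module's `DataProcessor`):

```python
from lightning.data.streaming.data_processor import DataProcessor

# Hypothetical directories; any str or Dir accepted by _resolve_dir works.
processor = DataProcessor(
    input_dir="/data/raw",
    output_dir="s3://my-bucket/optimized",
    num_workers=2,
    num_downloaders=2,  # the new default when None is passed
    num_uploaders=5,    # the new default when None is passed
)
# `CustomDataChunkRecipe` stands in for a user-defined DataChunkRecipe
# subclass, as in the updated tests below.
# processor.run(CustomDataChunkRecipe(chunk_size=2))
```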
@@ -816,30 +833,43 @@ def run(self, data_recipe: DataRecipe) -> None:
 
        current_total = 0
        has_failed = False
-        with tqdm(total=num_items, smoothing=0, position=-1, mininterval=1) as pbar:
-            while True:
-                try:
-                    error = self.error_queue.get(timeout=0.001)
-                    self._exit_on_error(error)
-                except Empty:
-                    assert self.progress_queue
-                    try:
-                        index, counter = self.progress_queue.get(timeout=0.001)
-                    except Empty:
-                        continue
-                    self.workers_tracker[index] = counter
-                    new_total = sum(self.workers_tracker.values())
-
-                    pbar.update(new_total - current_total)
-                    current_total = new_total
-                    if current_total == num_items:
-                        break
-
-                    # Exit early if all the workers are done.
-                    # This means there was some kind of error.
-                    if all(not w.is_alive() for w in self.workers):
-                        has_failed = True
-                        break
+        pbar = _tqdm(
+            desc="Progress",
+            total=num_items,
+            smoothing=0,
+            position=-1,
+            mininterval=1,
+            leave=True,
+            dynamic_ncols=True,
+        )
+
+        while True:
+            try:
+                error = self.error_queue.get(timeout=0.001)
+                self._exit_on_error(error)
+            except Empty:
+                assert self.progress_queue
+                try:
+                    index, counter = self.progress_queue.get(timeout=0.001)
+                except Empty:
+                    continue
+                self.workers_tracker[index] = counter
+                new_total = sum(self.workers_tracker.values())
+
+                pbar.set_postfix({"time": datetime.now().strftime("%H:%M:%S.%f")})
+                pbar.update(new_total - current_total)
+
+                current_total = new_total
+                if current_total == num_items:
+                    break
+
+                # Exit early if all the workers are done.
+                # This means there was some kind of error.
+                if all(not w.is_alive() for w in self.workers):
+                    has_failed = True
+                    break
+
+        pbar.close()
 
        num_nodes = _get_num_nodes()
        node_rank = _get_node_rank()
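The restructured loop polls two queues with short timeouts, errors first and progress second, and mirrors the per-worker counters into the bar. A condensed, standalone sketch of the same monitoring pattern (a generic function, not the `DataProcessor` method; error handling is reduced to a bare raise):

```python
from datetime import datetime
from multiprocessing import Queue
from queue import Empty
from typing import Dict

from tqdm.auto import tqdm as _tqdm

def monitor_progress(progress_queue: Queue, error_queue: Queue, num_items: int) -> None:
    workers_tracker: Dict[int, int] = {}
    current_total = 0
    pbar = _tqdm(desc="Progress", total=num_items, smoothing=0, mininterval=1)
    while current_total < num_items:
        try:
            raise RuntimeError(error_queue.get(timeout=0.001))  # a worker failed
        except Empty:
            try:
                index, counter = progress_queue.get(timeout=0.001)
            except Empty:
                continue
            workers_tracker[index] = counter
            new_total = sum(workers_tracker.values())
            pbar.set_postfix({"time": datetime.now().strftime("%H:%M:%S.%f")})
            pbar.update(new_total - current_total)
            current_total = new_total
    pbar.close()
```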
@@ -896,6 +926,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
                self.error_queue,
                stop_queues[-1],
                self.num_downloaders,
+                self.num_uploaders,
                self.delete_cached_files,
            )
            worker.start()
4 changes: 3 additions & 1 deletion tests/tests_data/streaming/test_data_processor.py
@@ -484,6 +484,8 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
        delete_cached_files=delete_cached_files,
        fast_dev_run=fast_dev_run,
        output_dir=remote_output_dir,
+        num_uploaders=1,
+        num_downloaders=1,
    )
    data_processor.run(CustomDataChunkRecipe(chunk_size=2))
 
@@ -508,6 +510,7 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
    data_processor = TestDataProcessor(
        input_dir=input_dir,
        num_workers=2,
+        num_uploaders=1,
        num_downloaders=1,
        delete_cached_files=delete_cached_files,
        fast_dev_run=fast_dev_run,
@@ -668,7 +671,6 @@ def test_data_processing_map(monkeypatch, tmpdir):
 
 
 def optimize_fn(filepath):
-    print(filepath)
    from PIL import Image
 
    return [Image.open(filepath), os.path.basename(filepath)]
