[ReadingService] Add round robin sharding to support non-replicable DataPipe for Multiprocessing #919
Conversation
@@ -225,6 +225,16 @@ def get_response_reset_iterator(self, block=False):
        if not isinstance(response, communication.messages.ResetIteratorResponse):
            raise Exception("Invalid response received")

    def get_response_reset_epoch(self, block=False):
Not sure why we didn't have this response before
Non-blocking, but hmmm... did this cause any bug or unhandled messages? Do you happen to know why? I know you were looking into unusual messages and responses.
When a new iteration starts, `reset` will be called for `_IterateQueueDataPipes`. And an extra `get_response_next` is invoked to drop the response, so all requests are served.
data/torchdata/dataloader2/reading_service.py, lines 152 to 154 in 0a0ae5d:

    for dp in self.datapipes:
        if dp.protocol.waiting_for_response():
            dp.protocol.get_response_next(block=True)
I can try removing this part.
It seems removing it is fine.
Noob question: are there any docs / entry pointers to understand how these dataloader2/communication components are designed and how they work?
Unfortunately there is no doc yet. I can describe the components based on my understanding:
- `ProtocolClient` remains in the main process to pass a `Request` via the `request_queue` to the corresponding worker process.
- `ProtocolServer` is created in the worker process; it takes the request and then sends a `Response` back to the main process via the `response_queue`.
- `DataPipeBehindQueues` is the worker loop that holds a `ProtocolServer` to manipulate the `DataPipe` based on the `Request`.
- `QueueWrapper` is the `DataPipe` that holds a `ProtocolClient` instance to issue a `Request` and yield data from the `response_queue` to the subsequent `DataPipe` graph.

We can talk about it in more detail offline if you want.
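To make that flow concrete, here is a minimal, hypothetical sketch of the request/response pattern over multiprocessing queues (simplified message classes and loop, not the actual torchdata implementation):

```python
import multiprocessing as mp

# Hypothetical stand-ins for communication.messages.*, used only for this sketch.
class GetNextRequest:
    ...

class GetNextResponse:
    def __init__(self, value):
        self.value = value

class StopIterationResponse:
    ...

def worker_loop(source, req_queue, res_queue):
    # Plays the role of ProtocolServer + DataPipeBehindQueues: take a request
    # from the queue, advance the iterator, and send a response back.
    it = iter(source)
    while True:
        req = req_queue.get()
        if isinstance(req, GetNextRequest):
            try:
                res_queue.put(GetNextResponse(next(it)))
            except StopIteration:
                res_queue.put(StopIterationResponse())
                break

class QueueClient:
    # Plays the role of ProtocolClient + QueueWrapper in the main process:
    # issue requests and yield values coming back on the response queue.
    def __init__(self, req_queue, res_queue):
        self.req_queue, self.res_queue = req_queue, res_queue

    def __iter__(self):
        while True:
            self.req_queue.put(GetNextRequest())
            res = self.res_queue.get()
            if isinstance(res, StopIterationResponse):
                return
            yield res.value

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    req_q, res_q = ctx.Queue(), ctx.Queue()
    p = ctx.Process(target=worker_loop, args=(range(5), req_q, res_q), daemon=True)
    p.start()
    print(list(QueueClient(req_q, res_q)))  # [0, 1, 2, 3, 4]
    p.join()
```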
Correct.
Let me try to draw it with mermaid
graph TD;
Worker_1_1-->ProtocolServer_1_1-->ProtocolClient_1
Worker_1_2-->ProtocolServer_1_2-->ProtocolClient_1
Worker_1_3-->ProtocolServer_1_3-->ProtocolClient_1
ProtocolClient_1-->GPU1
ProtocolClient_2-->GPU2
ProtocolClient_3-->GPU3
@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
I am going to need more time to look through this. What is the use case of a non-shardable DataPipe?
I think we should provide a definition of a non-shardable DataPipe and explain why it may occur (an abstract example would be helpful as well). In particular, as you mentioned, note that we do not want to duplicate such a DataPipe into multiple workers, and a sharding filter should not be applied to it. Instead, it should be read round-robin (or in some other way) by downstream DataPipes that are sharded?
Makes sense. I will add docs regarding non-shardable/shardable DataPipes to the documents for dataloader2.
Actually, this is not true. Edit: Updated the summary with a few topics that need to be covered. Let me know if there is any other concern about the documentation.
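As one abstract illustration of the kind of DataPipe being discussed (a hypothetical sketch added here, not code from this PR): a source that owns a live, single-consumer resource cannot simply be copied into every worker, because each copy would re-read or compete for the same underlying stream.

```python
from torch.utils.data import IterDataPipe

class SingleStreamSource(IterDataPipe):
    """Hypothetical non-replicable source: it reads from one live resource
    (e.g. a socket, a database cursor, a message queue) that cannot be
    duplicated across worker processes without re-reading the same data."""

    def __init__(self, open_stream):
        self.open_stream = open_stream  # callable returning the live stream

    def __iter__(self):
        yield from self.open_stream()
```

Keeping such a source in a single dispatching process and routing its output round-robin to the sharded, replicated branches downstream avoids both duplication and re-reading.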
@@ -84,9 +153,22 @@ def process_reset_fn(
    reset the random state of the ``DataPipe`` graph and the global random states for ``torch``,
    ``random`` and ``numpy``.
    """
    # Reset non-sharding process first
    graph = traverse_dps(datapipe)
    non_sharding_process_dps = find_dps(graph, communication.iter._IterateQueueDataPipes)
Question: Is it always the case that `_IterateQueueDataPipes` is the only non-sharding process?
Yes, because we will find the lowest common ancestor of all non-shardable DataPipes in the main process and replace it with this `_IterateQueueDataPipes` in the worker process. So, it's guaranteed that there is only a single non-sharding process.
@NivekT I propose changing this to find the lowest common ancestor of the non-replicable data sources and send them to the non-sharding process:

graph TD;
    DP1(non-replicable DP1)-->DP2;
    DP2-->DP5;
    DP3(non-replicable DP3)-->DP4;
    DP4-->DP5;
    DP5-->DP6;
    DP6-->fullsync;
    fullsync-->output;

The lowest common ancestor of all non-shardable data sources is DP5.
That makes sense. How do you plan to implement that? Will we still have users calling
We should allow either users calling
Replying to #919:

An alternative approach for distributed sharding would be to distribute the workload based on filename or some compression/encoding unit in the file (in Parquet it's called a "Page": https://parquet.apache.org/docs/concepts/, and in ORC I think it's called a "Stripe"). That would avoid reading the original data multiple times?
Non-shardable is an extremely bad name, as the data is actually being sharded by round-robin dispatching. The actual meaning here is to prevent copying the DataPipe to multiple processes. I might rename it to non-replicable DataPipe / dispatching process.
I agree with renaming "non-shardable" to "non-replicated" DataPipe. I suppose sometimes it is replicable but the users won't want to replicate it?
@wenleix @NivekT This PR has been updated, and the updated document can be found at https://ejguan.github.io/dataloader2.html#dynamic-sharding
# Lazily import to prevent circular import
from torchdata.dataloader2 import communication
This is my temporary fix for the circular import problem.
cc: @NivekT
The following review steps are added to make it easier to review the main logic in this PR. Let me know if anything is unclear.
@functional_datapipe("sharding_round_robin_dispatch")
class ShardingRoundRobinDispatcherIterDataPipe(IterDataPipe):
Review Step 1: `ShardingRoundRobinDispatcher` (functional name `sharding_round_robin_dispatch`) is introduced to indicate where the pipeline should become non-replicable. I am open to any suggestion on the class name/functional name.
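For context, a hypothetical usage sketch of the functional form (the argument and the `SHARDING_PRIORITIES` import path are assumptions based on the existing `sharding_filter` convention, not taken from this diff):

```python
# Import path of SHARDING_PRIORITIES may differ across PyTorch versions.
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.datapipes.iter import IterableWrapper

# Everything before the dispatcher stays in a single (dispatching) process;
# its output is handed out round-robin to the worker processes.
dp = IterableWrapper(range(100)).shuffle()
dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)

# DataPipes after the dispatcher are replicated into each worker as usual.
dp = dp.map(lambda x: x * 2)
```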
    def __iter__(self) -> Iterator[T_co]:
        yield from self.source_datapipe
Review Step 1.1: Keep `__iter__` as a simple pass-through here (rather than raising an error) to support the single-process use case.
Noob question: by "single-process use case", do you mean "eager mode"?
Yeah, pure eager currently. In the future, we might provide a by-default `SingleProcessReadingService` for users.
    res_queue: Queue


def find_lca_non_replicable_dp(graph: DataPipeGraph) -> Optional[DataPipe]:
Review Step 2: Add this graph function to find the lowest common ancestor of the non-replicable DataPipes (`ShardingRoundRobinDispatcher`).
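To make the intent concrete, here is a simplified, hypothetical sketch of the same idea on a plain adjacency structure rather than the real `DataPipeGraph` returned by `traverse_dps` (so types and traversal details differ from `find_lca_non_replicable_dp`):

```python
from typing import Dict, List, Optional, Set

# Toy graph: node name -> list of its upstream (input) nodes; "output" is the sink.
Graph = Dict[str, List[str]]

def find_lca_non_replicable(graph: Graph, sink: str, non_replicable: Set[str]) -> Optional[str]:
    """Return the node closest to the sources whose upstream subgraph still
    contains *all* non-replicable nodes, i.e. their lowest common ancestor
    when the graph is viewed as a tree rooted at the sink."""

    def covered(node: str) -> Set[str]:
        # Non-replicable nodes reachable upstream of (and including) `node`.
        result = {node} & non_replicable
        for parent in graph.get(node, []):
            result |= covered(parent)
        return result

    def search(node: str) -> Optional[str]:
        if covered(node) != non_replicable:
            return None
        # Prefer a deeper answer if a single upstream branch already covers all.
        for parent in graph.get(node, []):
            deeper = search(parent)
            if deeper is not None:
                return deeper
        return node

    return search(sink) if non_replicable else None

# The example discussed earlier in this thread:
graph = {
    "output": ["fullsync"], "fullsync": ["DP6"], "DP6": ["DP5"],
    "DP5": ["DP2", "DP4"], "DP2": ["DP1"], "DP4": ["DP3"],
    "DP1": [], "DP3": [],
}
print(find_lca_non_replicable(graph, "output", {"DP1", "DP3"}))  # DP5
```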
        graph = traverse_dps(end_dp)
        return single_br_dp, multi_br_dp, ch1, ch2, fork_zip_dp, cir_br_dp, cir_map_dp, end_dp, graph

    def test_single_non_replicable_dp(self):
Review Step 3.1: Tests for single non-replicable DataPipe
        graph, cir_map_dp = make_dp_non_replicable(graph, cir_map_dp)
        self.assertEqual(find_lca_non_replicable_dp(graph), cir_map_dp)

    def test_multi_non_replicable_dps(self):
Review Step 3.2: Tests for multiple non-replicable DataPipes
@@ -91,3 +131,20 @@ def SpawnThreadForDataPipeline(datapipe):

    process = threading.Thread(target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue), daemon=True)
    return process, req_queue, res_queue, new_datapipe


def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes):
Review Step 7.1: Create `num_workers` pairs of `req_queue` and `res_queue`, and launch `MultipleDataPipesToQueuesLoop` to iterate over the non-replicable DataPipe.
    ]


def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call_on_process_init=None):
Review Step 7.2: Launch a non-blocking `DataPipeBehindQueues` while-loop per child DataPipe from `round_robin_demux`. `zip_longest` is used to mimic round-robin calls to `next` over each child DataPipe.
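A minimal sketch of the `zip_longest`-based round robin described here (illustrative only; the real loop drives per-child non-blocking loops and handles queues, termination, and invalidation):

```python
from itertools import zip_longest

# Three "children" of a round_robin_demux, one per worker's stream.
children = [iter([1, 4, 7]), iter([2, 5]), iter([3, 6, 9, 12])]
_SENTINEL = object()

# zip_longest pulls one item from each child per round, mimicking calling
# `next` on each child DataPipe in turn, even after some children are exhausted.
for round_items in zip_longest(*children, fillvalue=_SENTINEL):
    for item in round_items:
        if item is not _SENTINEL:
            print(item, end=" ")  # 1 2 3 4 5 6 7 9 12
```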
# Dispatching process for non-replicable DataPipes exists
if self._dispatch_process is not None:
    # Use the placeholder to pass the request/response queues to each worker process
    dummy_dp.req_queue = self._dispatch_process[1][worker_id]
    dummy_dp.res_queue = self._dispatch_process[2][worker_id]
Review Step 6.2: We only have one `_DummyIterDataPipe` in the main process, but we have `num_workers` pairs of `req_queue` and `res_queue`. To connect a pair to the corresponding worker process, inject the attributes into `_DummyIterDataPipe` before sending it to the worker process.
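A hypothetical sketch of that attribute-injection idea with simplified names (the real code ships the whole serialized graph, not just the placeholder):

```python
import multiprocessing as mp

class _PlaceholderDP:
    """Hypothetical stand-in for _DummyIterDataPipe: an empty graph node whose
    queue attributes are filled in just before it is sent to a worker."""
    req_queue = None
    res_queue = None

def worker(worker_id, dp):
    # Each worker receives the placeholder with *its own* request/response
    # pair injected, so it has a dedicated channel to the dispatching process.
    dp.req_queue.put(f"hello from worker {worker_id}")

if __name__ == "__main__":
    num_workers = 3
    ctx = mp.get_context("spawn")
    req_queues = [ctx.Queue() for _ in range(num_workers)]
    res_queues = [ctx.Queue() for _ in range(num_workers)]

    dummy_dp = _PlaceholderDP()
    procs = []
    for worker_id in range(num_workers):
        # Inject the worker-specific pair right before the process is started
        # (and therefore before the object is pickled for that worker).
        dummy_dp.req_queue = req_queues[worker_id]
        dummy_dp.res_queue = res_queues[worker_id]
        p = ctx.Process(target=worker, args=(worker_id, dummy_dp))
        p.start()
        procs.append(p)

    for q in req_queues:
        print(q.get())  # one message arrives on each worker-specific queue
    for p in procs:
        p.join()
```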
# Find if there is non-replicable DataPipe
graph = traverse_dps(datapipe)
non_replicable_dp = find_dps(graph, _DummyIterDataPipe)  # type: ignore

# There are two cases for DataPipe graph in terms of mp sharding:
# 1) All DataPipes are replicable, apply mp sharding to the whole graph
if len(non_replicable_dp) == 0:
    torch.utils.data.graph_settings.apply_sharding(
        datapipe, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING
    )
# 2) There is non-replicable DataPipe. Since we have replaced the lowest common
#    ancestor by a `_DummyIterDataPipe`, we would only apply mp sharding
#    to replicable branches that don't have `_DummyIterDataPipe`.
else:
    assert len(non_replicable_dp) == 1
    replicable_branches = find_replicable_branches(graph)
    for dp in replicable_branches:
        torch.utils.data.graph_settings.apply_sharding(
            dp, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING
        )

    req_queue = non_replicable_dp[0].req_queue
    res_queue = non_replicable_dp[0].res_queue

    queue_wrapper = communication.iter.QueueWrapper(
        communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)
    )
    dispatch_process_dp = communication.iter._IterateQueueDataPipes([queue_wrapper])
    graph = replace_dp(graph, non_replicable_dp[0], dispatch_process_dp)
    datapipe = list(graph.values())[0][0]
Review Step 8: In the worker process, find whether there is a `_DummyIterDataPipe`. If not, the whole pipeline is replicable and we apply mp sharding to the whole graph via the sharding filter. If there is, we apply sharding only over the replicable branches. `QueueWrapper` and `_IterateQueueDataPipes` are used to wrap `res_queue` and `req_queue` as a DataPipe that can handle `Request` and `Response` based on the protocol.
def dispatch_process_reset_fn(
    datapipe: DataPipe,
    worker_info: WorkerInfo,
    dist_info: _DistInfo,
) -> DataPipe:
    r"""
    Based on the distributed shared random seed, this function is used to set the random state
    of the non-replicable ``DataPipe`` graph and the global random states for the dispatch process.
    This function would guarantee that all distributed non-sharding processes share the
    same random states to ensure the same shuffle order.
    """
    worker_seed_generator = torch.Generator()
    worker_seed_generator.manual_seed(dist_info.shared_seed)
    torch.utils.data.graph_settings.apply_random_seed(
        datapipe,
        worker_seed_generator,
    )

    # Set global random states
    _set_global_random_state(worker_seed_generator)

    return datapipe
Review Step 9: When a new epoch starts, we want to control the random seed based on the distributed information. We need to guarantee that all distributed dispatching processes share the same random seed.
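A small, illustrative check of the property being guaranteed here: generators seeded with the same shared seed produce identical shuffle orders, so every rank's dispatching process shuffles the non-replicable branch the same way (plain `torch.Generator`, not the reading-service code):

```python
import torch

SHARED_SEED = 12345  # the same distributed shared seed on every rank

def shuffle_order(seed: int, n: int = 8):
    g = torch.Generator()
    g.manual_seed(seed)
    return torch.randperm(n, generator=g).tolist()

# Two "ranks" seeded identically agree on the shuffle order.
assert shuffle_order(SHARED_SEED) == shuffle_order(SHARED_SEED)
print(shuffle_order(SHARED_SEED))
```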
"Review Step 1: Add ShardingRoundRobinDispatcher is introduced to indicate where the pipeline should be non-replicable."
LGTM. The name (`sharding_round_robin_dispatch`) is a bit long, but let's keep it for now...
"Review Step 7: create req_queue and res_queueand execute with round_robin_demux"
High-level control flow looks good, but I don't have enough low-level context yet.
        res_queues
    ), "``MultipleDataPipesToQueuesLoop`` requires the same number of datapipes, request queues and response queues"

    torch.set_num_threads(1)
noob question: what's this for?
IIRC, this was introduced to disable OpenMP multithreading in DataLoader workers. By default, OpenMP creates a number of threads equal to the number of CPU cores, so with multiprocessing enabled, `num_workers x num_threads_per_worker` threads would be created, which doesn't provide any further benefit. Besides, OpenMP features should not be enabled in subprocesses if any OpenMP features were already used in the main process before the subprocesses were forked.

Any suggestions?
"Review Step 8: Worker process handling (graph rewrite and receiving demux result from dispatch process)"
LGTM.
queue_wrapper = communication.iter.QueueWrapper(
    communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)
)
dispatch_process_dp = communication.iter._IterateQueueDataPipes([queue_wrapper])
IIUC: `IterDataPipeQueueProtocolClient` will be wrapped into a `QueueWrapper` (but still not an `IterDataPipe`), and further wrapped into an `_IterateQueueDataPipes`, which is an `IterDataPipe`?
Nope. Both `QueueWrapper` and `_IterateQueueDataPipes` are `IterDataPipe`s; this is one of the things we can optimize later.
Review Step 9: LGTM % minor question...
            yield response.value

    def reset(self):
        # NonBlocking DataPipes do not reset automatically, have to do it manually
@ejguan Just noticed that the `reset` method has changed after the move. It used to have this:

    # Collect all existing requests results to clear queues
    for dp in self.datapipes:
        if dp.protocol.waiting_for_response():
            dp.protocol.get_response_next(block=True)

Is this no longer necessary?
I don't think it's necessary, because we always want to do `reset_epoch` as well for `NonBlocking`, which will discard all existing requests. So, at the point of `reset`, we should expect no pending requests within the worker process queues.
This PR is created on top of #555 and extends `PrototypeMultiprocessingReadingService` to accept non-replicable DataPipes. It also depends on pytorch/pytorch#90769.

Main Changes
- Add `ShardingRoundRobinDispatcher` (functional name `sharding_round_robin_dispatch`) to indicate a non-replicable DataPipe
- Add `MultipleDataPipesToQueuesLoop` to connect the non-sharding (dispatching) process to request/response queues
- Add `find_lca_non_replicable_dp` as a graph function to determine the lowest common ancestor of all non-replicable DataPipes. This guarantees that all non-replicable DataPipes run in a single dispatching process
- Add `find_replicable_branches` to apply mp sharding to the replicable branches, because all non-replicable branches have already been properly sharded by routing data round-robin to worker processes
- Add `ResetEpochResponse` to the protocol, retrieved via `get_response_reset_epoch`
Please check the link for the documentation: https://ejguan.github.io/dataloader2.html#dynamic-sharding
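For completeness, a hypothetical end-to-end sketch of the intended usage (class and function names follow the PR summary and may differ from the final released API; the `SHARDING_PRIORITIES` import path is version-dependent):

```python
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import DataLoader2, PrototypeMultiprocessingReadingService
from torchdata.datapipes.iter import IterableWrapper

def add_one(x):
    return x + 1

if __name__ == "__main__":
    # The shuffle runs once, in the dispatching process; its output is
    # dispatched round-robin, and the map is replicated into each worker.
    dp = IterableWrapper(range(1000)).shuffle()
    dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    dp = dp.map(add_one)

    rs = PrototypeMultiprocessingReadingService(num_workers=2)
    dl = DataLoader2(dp, reading_service=rs)
    for item in dl:
        pass
    dl.shutdown()
```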
nit Changes
- Rename `Spawn` to `Create`, as the process has not been started yet