From 8e4f9f7c83c4b0c7ada697d9b8fc4b064ea3899e Mon Sep 17 00:00:00 2001 From: DaGuT Date: Tue, 8 Sep 2020 13:50:38 +0300 Subject: [PATCH 1/5] Fix max_depth recursion crash in AsynchronousLoader Some people use dict samples with strings, numpy array etc. They are returned from dataloader and are passed to model. Later those are parsed/replaced/removed inside of step/forward function. Example of such dict (I use it): `{'image': tensor([WHATEVER]), 'path': [WHATEVER_STRING], 'target': tensor([WHATEVER_NUMPY]), 'meta': {'PatientAge': tensor([WHATEVER]), 'PatientSex': tensor([WHATEVER]), 'StudyDate': tensor([WHATEVER])}, 'mask': tensor([WHATEVER]), 'bboxes': [[tensor([WHATEVER]), tensor([WHATEVER]), tensor([WHATEVER])]]}` Sadly, async loader was crashing as max depth of recursion was reached (e.g. when it meets string, it just infinitely goes to else in original AsyncLoader). I went to torch.utils.data.default_collate (as it's said that code is based on it), took few lines from there and it works now. So it will no longer crash at named_tuples, dict-like objects and strings. Those will all be processed. --- pl_bolts/datamodules/async_dataloader.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 6527d39f2b..05fca59c4c 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -1,10 +1,12 @@ from queue import Queue from threading import Thread +import re + import torch from torch.utils.data import DataLoader from torch.utils.data import Dataset - +from torch._six import container_abcs, string_classes, int_classes class AsynchronousLoader(object): """ @@ -54,14 +56,30 @@ def load_loop(self): # The loop that will load into the queue in the background # Recursive loading for each instance based on torch.utils.data.default_collate def load_instance(self, sample): + np_str_obj_array_pattern = re.compile(r'[SaUO]') + + elem_type = type(sample) + if torch.is_tensor(sample): with torch.cuda.stream(self.load_stream): # Can only do asynchronous transfer if we use pin_memory if not sample.is_pinned(): sample = sample.pin_memory() return sample.to(self.device, non_blocking=True) - else: + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' \ + and np_str_obj_array_pattern.search(sample.dtype.str) is not None: + return self.load_instance(sample) + return self.load_instance(torch.as_tensor(sample)) + elif isinstance(sample, container_abcs.Mapping): + return {key: self.load_instance(sample[key]) for key in sample} + elif isinstance(sample, tuple) and hasattr(sample, '_fields'): # namedtuple + return elem_type(*(self.load_instance(d) for d in sample)) + elif isinstance(sample, container_abcs.Sequence) and not isinstance(sample, string_classes): return [self.load_instance(s) for s in sample] + else: + return sample def __iter__(self): # We don't want to run the thread more than once From 130d397494999e72b25889ab1563836a4b93f653 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D0=BB=D0=B5=D0=BA=D1=81=D0=B0=D0=BD=D0=B4=D1=80=20?= =?UTF-8?q?=D0=9C=D0=BE=D0=BD=D0=B3=D0=BE=D0=BB=D0=B8=D0=BD?= Date: Tue, 8 Sep 2020 14:17:47 +0300 Subject: [PATCH 2/5] made "else" as it was before --- pl_bolts/datamodules/async_dataloader.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 05fca59c4c..3d71f56479 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -8,6 +8,7 @@ from torch.utils.data import Dataset from torch._six import container_abcs, string_classes, int_classes + class AsynchronousLoader(object): """ Class for asynchronously loading from CPU memory to device memory with DataLoader. @@ -38,7 +39,8 @@ def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches= elif hasattr(self.dataloader, '__len__'): self.num_batches = len(self.dataloader) else: - raise Exception("num_batches must be specified or data must have finite __len__") + raise Exception( + "num_batches must be specified or data must have finite __len__") self.device = device self.q_size = q_size @@ -76,10 +78,10 @@ def load_instance(self, sample): return {key: self.load_instance(sample[key]) for key in sample} elif isinstance(sample, tuple) and hasattr(sample, '_fields'): # namedtuple return elem_type(*(self.load_instance(d) for d in sample)) - elif isinstance(sample, container_abcs.Sequence) and not isinstance(sample, string_classes): - return [self.load_instance(s) for s in sample] - else: + elif isinstance(sample, string_classes): return sample + else: + return [self.load_instance(s) for s in sample] def __iter__(self): # We don't want to run the thread more than once @@ -107,4 +109,4 @@ def __next__(self): return out def __len__(self): - return self.num_batches + return self.num_batches \ No newline at end of file From 4f539b274e090b564b5b46dfc6157319b7b5bd04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D0=BB=D0=B5=D0=BA=D1=81=D0=B0=D0=BD=D0=B4=D1=80=20?= =?UTF-8?q?=D0=9C=D0=BE=D0=BD=D0=B3=D0=BE=D0=BB=D0=B8=D0=BD?= Date: Tue, 8 Sep 2020 14:33:15 +0300 Subject: [PATCH 3/5] Revert "made "else" as it was before" This reverts commit 130d397494999e72b25889ab1563836a4b93f653. --- pl_bolts/datamodules/async_dataloader.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 3d71f56479..05fca59c4c 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -8,7 +8,6 @@ from torch.utils.data import Dataset from torch._six import container_abcs, string_classes, int_classes - class AsynchronousLoader(object): """ Class for asynchronously loading from CPU memory to device memory with DataLoader. @@ -39,8 +38,7 @@ def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches= elif hasattr(self.dataloader, '__len__'): self.num_batches = len(self.dataloader) else: - raise Exception( - "num_batches must be specified or data must have finite __len__") + raise Exception("num_batches must be specified or data must have finite __len__") self.device = device self.q_size = q_size @@ -78,10 +76,10 @@ def load_instance(self, sample): return {key: self.load_instance(sample[key]) for key in sample} elif isinstance(sample, tuple) and hasattr(sample, '_fields'): # namedtuple return elem_type(*(self.load_instance(d) for d in sample)) - elif isinstance(sample, string_classes): - return sample - else: + elif isinstance(sample, container_abcs.Sequence) and not isinstance(sample, string_classes): return [self.load_instance(s) for s in sample] + else: + return sample def __iter__(self): # We don't want to run the thread more than once @@ -109,4 +107,4 @@ def __next__(self): return out def __len__(self): - return self.num_batches \ No newline at end of file + return self.num_batches From 2ba3bd1f7d202c64a1069fcaab941305178092b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D0=BB=D0=B5=D0=BA=D1=81=D0=B0=D0=BD=D0=B4=D1=80=20?= =?UTF-8?q?=D0=9C=D0=BE=D0=BD=D0=B3=D0=BE=D0=BB=D0=B8=D0=BD?= Date: Tue, 8 Sep 2020 14:38:01 +0300 Subject: [PATCH 4/5] autopep8 format --- pl_bolts/datamodules/async_dataloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 05fca59c4c..17e46e6f13 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -8,6 +8,7 @@ from torch.utils.data import Dataset from torch._six import container_abcs, string_classes, int_classes + class AsynchronousLoader(object): """ Class for asynchronously loading from CPU memory to device memory with DataLoader. @@ -15,7 +16,7 @@ class AsynchronousLoader(object): Note that this only works for single GPU training, multiGPU uses PyTorch's DataParallel or DistributedDataParallel which uses its own code for transferring data across GPUs. This could just break or make things slower with DataParallel or DistributedDataParallel. - + Args: data: The PyTorch Dataset or DataLoader we're using to load. device: The PyTorch device we are loading to From d04a41e9a60b0698bc5810668d9ffce6ce8517af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D0=BB=D0=B5=D0=BA=D1=81=D0=B0=D0=BD=D0=B4=D1=80=20?= =?UTF-8?q?=D0=9C=D0=BE=D0=BD=D0=B3=D0=BE=D0=BB=D0=B8=D0=BD?= Date: Tue, 8 Sep 2020 14:52:32 +0300 Subject: [PATCH 5/5] moved np_str_obj_array_pattern definition to init --- pl_bolts/datamodules/async_dataloader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 17e46e6f13..82abfb5223 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -16,7 +16,7 @@ class AsynchronousLoader(object): Note that this only works for single GPU training, multiGPU uses PyTorch's DataParallel or DistributedDataParallel which uses its own code for transferring data across GPUs. This could just break or make things slower with DataParallel or DistributedDataParallel. - + Args: data: The PyTorch Dataset or DataLoader we're using to load. device: The PyTorch device we are loading to @@ -49,6 +49,8 @@ def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches= self.idx = 0 + self.np_str_obj_array_pattern = re.compile(r'[SaUO]') + def load_loop(self): # The loop that will load into the queue in the background for i, sample in enumerate(self.dataloader): self.queue.put(self.load_instance(sample)) @@ -57,8 +59,6 @@ def load_loop(self): # The loop that will load into the queue in the background # Recursive loading for each instance based on torch.utils.data.default_collate def load_instance(self, sample): - np_str_obj_array_pattern = re.compile(r'[SaUO]') - elem_type = type(sample) if torch.is_tensor(sample): @@ -70,7 +70,7 @@ def load_instance(self, sample): elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ and elem_type.__name__ != 'string_': if elem_type.__name__ == 'ndarray' \ - and np_str_obj_array_pattern.search(sample.dtype.str) is not None: + and self.np_str_obj_array_pattern.search(sample.dtype.str) is not None: return self.load_instance(sample) return self.load_instance(torch.as_tensor(sample)) elif isinstance(sample, container_abcs.Mapping):