From fd5a4a358900b3c1b28b09d0b30cde981a063106 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Wed, 20 May 2020 12:58:40 -0400 Subject: [PATCH] Fix race in Scanner::ToTable --- cpp/src/arrow/dataset/scanner.cc | 38 ++++++++++++++++------------ python/pyarrow/tests/test_dataset.py | 7 ++--- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 97b2daf9e6137..cf473b94ae5ad 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -177,32 +177,38 @@ static inline RecordBatchVector FlattenRecordBatchVector( return flattened; } +struct TableAssemblyState { + /// Protecting mutating accesses to batches + std::mutex mutex{}; + std::vector batches{}; + + void Emplace(RecordBatchVector b, size_t position) { + std::lock_guard lock(mutex); + if (batches.size() <= position) { + batches.resize(position + 1); + } + batches[position] = std::move(b); + } +}; + Result> Scanner::ToTable() { ARROW_ASSIGN_OR_RAISE(auto scan_task_it, Scan()); auto task_group = scan_context_->TaskGroup(); - // Protecting mutating accesses to batches - std::mutex mutex; - std::vector batches; + /// Wraps the state in a shared_ptr to ensure that a failing ScanTask don't + /// invalidate the concurrent running tasks because Finish() early returns + /// and the mutex/batches may got out of scope. + auto state = std::make_shared(); + size_t scan_task_id = 0; for (auto maybe_scan_task : scan_task_it) { ARROW_ASSIGN_OR_RAISE(auto scan_task, std::move(maybe_scan_task)); auto id = scan_task_id++; - task_group->Append([&batches, &mutex, id, scan_task] { + task_group->Append([state, id, scan_task] { ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute()); - ARROW_ASSIGN_OR_RAISE(auto local, batch_it.ToVector()); - - { - // Move into global batches. - std::lock_guard lock(mutex); - if (batches.size() <= id) { - batches.resize(id + 1); - } - batches[id] = std::move(local); - } - + state->Emplace(std::move(local), id); return Status::OK(); }); } @@ -211,7 +217,7 @@ Result> Scanner::ToTable() { RETURN_NOT_OK(task_group->Finish()); return Table::FromRecordBatches(scan_options_->schema(), - FlattenRecordBatchVector(std::move(batches))); + FlattenRecordBatchVector(std::move(state->batches))); } } // namespace dataset diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 28c2c3a27795d..82991788dad47 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -1485,11 +1485,8 @@ def test_parquet_dataset_factory_invalid(tempdir): dataset = ds.parquet_dataset(str(root_path / '_metadata')) assert dataset.schema.equals(table.schema) assert len(dataset.files) == 4 - # TODO this segfaults with - # terminate called after throwing an instance of 'std::system_error' - # what(): Invalid argument - # with pytest.raises(ValueError): - # dataset.to_table() + with pytest.raises(FileNotFoundError): + dataset.to_table() def _create_metadata_file(root_path):