diff --git a/src/runtime/vm/pooled_allocator.h b/src/runtime/vm/pooled_allocator.h index bb088c5653f2..c282eb006f92 100644 --- a/src/runtime/vm/pooled_allocator.h +++ b/src/runtime/vm/pooled_allocator.h @@ -45,7 +45,7 @@ class PooledAllocator final : public Allocator { ~PooledAllocator() { ReleaseAll(); } Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { - std::lock_guard lock(mu_); + std::lock_guard lock(mu_); size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; auto&& it = memory_pool_.find(size); if (it != memory_pool_.end() && !it->second.empty()) { @@ -57,14 +57,22 @@ class PooledAllocator final : public Allocator { Buffer buf; buf.device = device_; buf.size = size; - buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + try { + buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } catch (InternalError& err) { + LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message(); + LOG(WARNING) << "Trying to release all unused memory and reallocate..."; + ReleaseAll(); + buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } + used_memory_.fetch_add(size, std::memory_order_relaxed); DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; return buf; } void Free(const Buffer& buffer) override { - std::lock_guard lock(mu_); + std::lock_guard lock(mu_); if (memory_pool_.find(buffer.size) == memory_pool_.end()) { memory_pool_.emplace(buffer.size, std::vector{}); } @@ -76,7 +84,7 @@ class PooledAllocator final : public Allocator { private: void ReleaseAll() { - std::lock_guard lock(mu_); + std::lock_guard lock(mu_); for (auto const& it : memory_pool_) { auto const& pool = it.second; for (auto const& buf : pool) { @@ -92,7 +100,7 @@ class PooledAllocator final : public Allocator { size_t page_size_; std::atomic used_memory_; std::unordered_map > memory_pool_; - std::mutex mu_; + std::recursive_mutex mu_; Device device_; };