From 820f1b617a4f8ccf196803c5e48a4f155c929c4a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 30 May 2024 11:41:03 -0500 Subject: [PATCH] [Runtime] Compatibility with dmlc::Stream API changes (#16998) * [Runtime] Compatibility with dmlc::Stream API changes This commit updates TVM implementations of `dmlc::Stream`. With https://github.com/dmlc/dmlc-core/pull/686, this API now requires the `Write` method to return the number of bytes written. This change allows partial writes to be correctly handled. * Update dmlc-core version * lint fix --- 3rdparty/dmlc-core | 2 +- src/runtime/disco/process_session.cc | 3 ++- src/runtime/disco/threaded_session.cc | 3 ++- src/runtime/file_utils.h | 8 ++++++-- src/runtime/rpc/rpc_endpoint.cc | 8 ++++++-- src/runtime/rpc/rpc_socket_impl.cc | 7 ++----- src/support/base64.h | 5 +++-- src/support/pipe.h | 24 +++++++++++------------- 8 files changed, 33 insertions(+), 27 deletions(-) diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 09511cf9fe5f..3031e4a61a98 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 09511cf9fe5ff103900a5eafb50870dc84cc17c8 +Subproject commit 3031e4a61a98f49f07a42cfdec6242340fb2fd8c diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index b50775877733..179010db8a23 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -113,10 +113,11 @@ class DiscoPipeMessageQueue : private dmlc::Stream, private DiscoProtocolWrite(data, size); } + // write the data to the channel. + size_t Write(const void* data, size_t size) final { + writer_->Write(data, size); + return size; + } + // Number of pending bytes requests size_t pending_request_bytes_{0}; // The ring buffer to read data from. diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 1d0b5d5470c8..6882ba4deda9 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -159,11 +159,8 @@ class SimpleSockHandler : public dmlc::Stream { // Internal supporting. // Override methods that inherited from dmlc::Stream. private: - size_t Read(void* data, size_t size) final { - ICHECK_EQ(sock_.RecvAll(data, size), size); - return size; - } - void Write(const void* data, size_t size) final { ICHECK_EQ(sock_.SendAll(data, size), size); } + size_t Read(void* data, size_t size) final { return sock_.Recv(data, size); } + size_t Write(const void* data, size_t size) final { return sock_.Send(data, size); } // Things of current class. private: diff --git a/src/support/base64.h b/src/support/base64.h index aba4197bce20..2bfc42c27fb1 100644 --- a/src/support/base64.h +++ b/src/support/base64.h @@ -206,7 +206,7 @@ class Base64InStream : public dmlc::Stream { } return size - tlen; } - virtual void Write(const void* ptr, size_t size) { + size_t Write(const void* ptr, size_t size) final { LOG(FATAL) << "Base64InStream do not support write"; } @@ -229,7 +229,7 @@ class Base64OutStream : public dmlc::Stream { using dmlc::Stream::Write; - void Write(const void* ptr, size_t size) final { + size_t Write(const void* ptr, size_t size) final { using base64::EncodeTable; size_t tlen = size; const unsigned char* cptr = static_cast(ptr); @@ -247,6 +247,7 @@ class Base64OutStream : public dmlc::Stream { buf__top_ = 0; } } + return size; } virtual size_t Read(void* ptr, size_t size) { LOG(FATAL) << "Base64OutStream do not support read"; diff --git a/src/support/pipe.h b/src/support/pipe.h index 7251a6f14ae2..9d5aa1e48643 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -112,8 +112,8 @@ class Pipe : public dmlc::Stream { * \param size block size * \return the size of data read */ - void Write(const void* ptr, size_t size) final { - if (size == 0) return; + size_t Write(const void* ptr, size_t size) final { + if (size == 0) return 0; #ifdef _WIN32 auto fwrite = [&]() -> ssize_t { DWORD nwrite; @@ -124,18 +124,16 @@ class Pipe : public dmlc::Stream { DWORD nwrite = static_cast(RetryCallOnEINTR(fwrite, GetLastErrorCode)); ICHECK_EQ(static_cast(nwrite), size) << "Write Error: " << GetLastError(); #else - while (size) { - ssize_t nwrite = - RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode); - ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno); - - ICHECK_GT(nwrite, 0) << "Was unable to write any data to pipe"; - ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, " - << "but only expected to write " << size << " bytes"; - size -= nwrite; - ptr = static_cast(ptr) + nwrite; - } + ssize_t nwrite = + RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode); + ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno); + + ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, " + << "but only expected to write " << size << " bytes"; + #endif + + return nwrite; } /*! * \brief Flush the pipe;