Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Runtime] Compatibility with dmlc::Stream API changes #16998

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/runtime/disco/process_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ class DiscoPipeMessageQueue : private dmlc::Stream, private DiscoProtocol<DiscoP
return size;
}

void Write(const void* data, size_t size) final {
size_t Write(const void* data, size_t size) final {
size_t cur_size = write_buffer_.size();
write_buffer_.resize(cur_size + size);
std::memcpy(write_buffer_.data() + cur_size, data, size);
return size;
}

using dmlc::Stream::Read;
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/disco/threaded_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ class DiscoThreadedMessageQueue : private dmlc::Stream,
return size;
}

void Write(const void* data, size_t size) final {
size_t Write(const void* data, size_t size) final {
size_t cur_size = write_buffer_.size();
write_buffer_.resize(cur_size + size);
std::memcpy(write_buffer_.data() + cur_size, data, size);
return size;
}

using dmlc::Stream::Read;
Expand Down
8 changes: 6 additions & 2 deletions src/runtime/file_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,14 @@ struct SimpleBinaryFileStream : public dmlc::Stream {
CHECK(fp_ != nullptr) << "File is closed";
return std::fread(ptr, 1, size, fp_);
}
virtual void Write(const void* ptr, size_t size) {
virtual size_t Write(const void* ptr, size_t size) {
CHECK(!read_) << "File opened in read-mode, cannot write.";
CHECK(fp_ != nullptr) << "File is closed";
CHECK(std::fwrite(ptr, 1, size, fp_) == size) << "SimpleBinaryFileStream.Write incomplete";
size_t nwrite = std::fwrite(ptr, 1, size, fp_);
int err = std::ferror(fp_);

CHECK_EQ(err, 0) << "SimpleBinaryFileStream.Write incomplete: " << std::strerror(err);
return nwrite;
}
inline void Close(void) {
if (fp_ != nullptr) {
Expand Down
8 changes: 6 additions & 2 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
pending_request_bytes_ -= size;
return size;
}
// wriite the data to the channel.
void Write(const void* data, size_t size) final { writer_->Write(data, size); }
// write the data to the channel.
size_t Write(const void* data, size_t size) final {
writer_->Write(data, size);
return size;
}

// Number of pending bytes requests
size_t pending_request_bytes_{0};
// The ring buffer to read data from.
Expand Down
7 changes: 2 additions & 5 deletions src/runtime/rpc/rpc_socket_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,8 @@ class SimpleSockHandler : public dmlc::Stream {
// Internal supporting.
// Override methods that inherited from dmlc::Stream.
private:
size_t Read(void* data, size_t size) final {
ICHECK_EQ(sock_.RecvAll(data, size), size);
return size;
}
void Write(const void* data, size_t size) final { ICHECK_EQ(sock_.SendAll(data, size), size); }
size_t Read(void* data, size_t size) final { return sock_.Recv(data, size); }
size_t Write(const void* data, size_t size) final { return sock_.Send(data, size); }

// Things of current class.
private:
Expand Down
5 changes: 3 additions & 2 deletions src/support/base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class Base64InStream : public dmlc::Stream {
}
return size - tlen;
}
virtual void Write(const void* ptr, size_t size) {
size_t Write(const void* ptr, size_t size) final {
LOG(FATAL) << "Base64InStream do not support write";
}

Expand All @@ -229,7 +229,7 @@ class Base64OutStream : public dmlc::Stream {

using dmlc::Stream::Write;

void Write(const void* ptr, size_t size) final {
size_t Write(const void* ptr, size_t size) final {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char* cptr = static_cast<const unsigned char*>(ptr);
Expand All @@ -247,6 +247,7 @@ class Base64OutStream : public dmlc::Stream {
buf__top_ = 0;
}
}
return size;
}
virtual size_t Read(void* ptr, size_t size) {
LOG(FATAL) << "Base64OutStream do not support read";
Expand Down
24 changes: 11 additions & 13 deletions src/support/pipe.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ class Pipe : public dmlc::Stream {
* \param size block size
* \return the size of data read
*/
void Write(const void* ptr, size_t size) final {
if (size == 0) return;
size_t Write(const void* ptr, size_t size) final {
if (size == 0) return 0;
#ifdef _WIN32
auto fwrite = [&]() -> ssize_t {
DWORD nwrite;
Expand All @@ -124,18 +124,16 @@ class Pipe : public dmlc::Stream {
DWORD nwrite = static_cast<DWORD>(RetryCallOnEINTR(fwrite, GetLastErrorCode));
ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError();
#else
while (size) {
ssize_t nwrite =
RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno);

ICHECK_GT(nwrite, 0) << "Was unable to write any data to pipe";
ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, "
<< "but only expected to write " << size << " bytes";
size -= nwrite;
ptr = static_cast<const char*>(ptr) + nwrite;
}
ssize_t nwrite =
RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno);

ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, "
<< "but only expected to write " << size << " bytes";

#endif

return nwrite;
}
/*!
* \brief Flush the pipe;
Expand Down
Loading