Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IO] Return bytes written in Stream::Write #686

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions include/dmlc/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ class Stream { // NOLINT(*)
/*!
* \brief reads data from a stream
* \param ptr pointer to a memory buffer
* \param size block size
* \return the size of data read
* \param size The maximum number of bytes to read
* \return The number of bytes read from the stream
*/
virtual size_t Read(void *ptr, size_t size) = 0;
virtual size_t Read(void* ptr, size_t size) = 0;
/*!
* \brief writes data to a stream
* \param ptr pointer to a memory buffer
* \param size block size
* \param size The maximum number of bytes to write
* \return The number of bytes written
*/
virtual void Write(const void *ptr, size_t size) = 0;
virtual size_t Write(const void* ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~Stream(void) {}
/*!
Expand Down
22 changes: 12 additions & 10 deletions include/dmlc/memory_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,24 @@ struct MemoryFixedSizeStream : public SeekStream {
buffer_size_(buffer_size) {
curr_ptr_ = 0;
}
virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
CHECK(curr_ptr_ + size <= buffer_size_);
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
virtual size_t Write(const void *ptr, size_t size) override {
if (size == 0) return 0;
CHECK(curr_ptr_ + size <= buffer_size_);
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
return size;
}
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
return curr_ptr_;
}

Expand All @@ -73,25 +74,26 @@ struct MemoryStringStream : public dmlc::SeekStream {
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
CHECK(curr_ptr_ <= p_buffer_->length());
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
virtual size_t Write(const void *ptr, size_t size) override {
if (size == 0) return 0;
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
}
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
return size;
}
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
return curr_ptr_;
}

Expand Down
9 changes: 5 additions & 4 deletions src/io/hdfs_filesys.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class HDFSStream : public SeekStream {
}
}

virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
char *buf = static_cast<char*>(ptr);
size_t nleft = size;
size_t nmax = static_cast<size_t>(std::numeric_limits<tSize>::max());
Expand All @@ -48,7 +48,7 @@ class HDFSStream : public SeekStream {
return size - nleft;
}

virtual void Write(const void *ptr, size_t size) {
virtual size_t Write(const void *ptr, size_t size) override {
const char *buf = reinterpret_cast<const char*>(ptr);
size_t nleft = size;
// When using builtin-java classes to write, the maximum write size
Expand All @@ -70,14 +70,15 @@ class HDFSStream : public SeekStream {
LOG(FATAL) << "HDFSStream.hdfsWrite Error:" << strerror(errsv);
}
}
return size - nleft;
}
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
if (hdfsSeek(fs_, fp_, pos) != 0) {
int errsv = errno;
LOG(FATAL) << "HDFSStream.hdfsSeek Error:" << strerror(errsv);
}
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
tOffset offset = hdfsTell(fs_, fp_);
if (offset == -1) {
int errsv = errno;
Expand Down
9 changes: 5 additions & 4 deletions src/io/local_filesys.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,22 @@ class FileStream : public SeekStream {
virtual ~FileStream(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
return std::fread(ptr, 1, size, fp_);
}
virtual void Write(const void *ptr, size_t size) {
virtual size_t Write(const void *ptr, size_t size) override {
CHECK(std::fwrite(ptr, 1, size, fp_) == size)
<< "FileStream.Write incomplete";
return 0;
}
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
#ifndef _MSC_VER
CHECK(!std::fseek(fp_, static_cast<long>(pos), SEEK_SET)); // NOLINT(*)
#else // _MSC_VER
CHECK(!_fseeki64(fp_, pos, SEEK_SET));
#endif // _MSC_VER
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
#ifndef _MSC_VER
return std::ftell(fp_);
#else // _MSC_VER
Expand Down
16 changes: 9 additions & 7 deletions src/io/s3_filesys.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,23 +424,24 @@ class CURLReadStreamBase : public SeekStream {
virtual ~CURLReadStreamBase() {
this->Cleanup();
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
return curr_bytes_;
}
virtual bool AtEnd(void) const {
return at_end_;
}
virtual void Write(const void *ptr, size_t size) {
virtual size_t Write(const void *ptr, size_t size) override {
LOG(FATAL) << "CURL.ReadStream cannot be used for write";
return 0;
}
// lazy seek function
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
if (curr_bytes_ != pos) {
this->Cleanup();
curr_bytes_ = pos;
}
}
virtual size_t Read(void *ptr, size_t size);
virtual size_t Read(void *ptr, size_t size) override ;

protected:
CURLReadStreamBase()
Expand Down Expand Up @@ -790,11 +791,11 @@ class WriteStream : public Stream {
ecurl_ = curl_easy_init();
this->Init();
}
virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
LOG(FATAL) << "S3.WriteStream cannot be used for read";
return 0;
}
virtual void Write(const void *ptr, size_t size);
virtual size_t Write(const void *ptr, size_t size) override;
// destructor
virtual ~WriteStream() {
this->Close();
Expand Down Expand Up @@ -863,13 +864,14 @@ class WriteStream : public Stream {
void Finish(void);
};

void WriteStream::Write(const void *ptr, size_t size) {
size_t WriteStream::Write(const void *ptr, size_t size) {
size_t rlen = buffer_.length();
buffer_.resize(rlen + size);
std::memcpy(BeginPtr(buffer_) + rlen, ptr, size);
if (buffer_.length() >= max_buffer_size_) {
this->Upload();
}
return size;
}

void WriteStream::Run(const std::string &method,
Expand Down
3 changes: 2 additions & 1 deletion src/io/single_file_split.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ class SingleFileSplit : public InputSplit {
CHECK(part_index == 0 && num_parts == 1);
this->BeforeFirst();
}
virtual void Write(const void * /*ptr*/, size_t /*size*/) {
virtual size_t Write(const void * /*ptr*/, size_t /*size*/) {
LOG(FATAL) << "InputSplit do not support write";
return 0;
}
virtual bool NextRecord(Blob *out_rec) {
if (chunk_begin_ == chunk_end_) {
Expand Down