diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h index bd3bec903180..20806f5ff136 100644 --- a/src/runtime/file_utils.h +++ b/src/runtime/file_utils.h @@ -129,6 +129,43 @@ std::string SaveParams(const Map<String, NDArray>& params); * \param params Parameters to save. */ void SaveParams(dmlc::Stream* strm, const Map<String, NDArray>& params); + +/*! + * \brief A dmlc stream which wraps standard file operations. + */ +struct SimpleBinaryFileStream : public dmlc::Stream { + public: + SimpleBinaryFileStream(const std::string& path, std::string mode) { + const char* fname = path.c_str(); + + CHECK(mode == "wb" || mode == "rb") << "Only allowed modes are 'wb' and 'rb'"; + read_ = mode == "rb"; + fp_ = std::fopen(fname, mode.c_str()); + CHECK(fp_ != nullptr) << "Unable to open file " << path; + } + virtual ~SimpleBinaryFileStream(void) { this->Close(); } + virtual size_t Read(void* ptr, size_t size) { + CHECK(read_) << "File opened in write-mode, cannot read."; + CHECK(fp_ != nullptr) << "File is closed"; + return std::fread(ptr, 1, size, fp_); + } + virtual void Write(const void* ptr, size_t size) { + CHECK(!read_) << "File opened in read-mode, cannot write."; + CHECK(fp_ != nullptr) << "File is closed"; + CHECK(std::fwrite(ptr, 1, size, fp_) == size) << "SimpleBinaryFileStream.Write incomplete"; + } + inline void Close(void) { + if (fp_ != nullptr) { + std::fclose(fp_); + fp_ = nullptr; + } + } + + private: + std::FILE* fp_ = nullptr; + bool read_; +}; // class SimpleBinaryFileStream + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_FILE_UTILS_H_ diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 082ff0556544..5696bc5314c1 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -344,10 +344,8 @@ void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byt } void Executable::MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit) { - std::string bytes; - dmlc::MemoryStringStream stream(&bytes); + tvm::runtime::SimpleBinaryFileStream stream(path, "wb"); MoveLateBoundConstantsToStream(&stream, byte_limit); - SaveBinaryToFile(path, bytes); } void Executable::LoadLateBoundConstantsFromStream(dmlc::Stream* stream) { @@ -381,9 +379,7 @@ void Executable::LoadLateBoundConstantsFromMap(Map<String, NDArray> map) { } void Executable::LoadLateBoundConstantsFromFile(const std::string& path) { - std::string bytes; - LoadBinaryFromFile(path, &bytes); - dmlc::MemoryStringStream stream(&bytes); + tvm::runtime::SimpleBinaryFileStream stream(path, "rb"); LoadLateBoundConstantsFromStream(&stream); } @@ -1063,22 +1059,16 @@ Module ExecutableLoadBinary(void* strm) { } void Executable::SaveToFile(const std::string& path, const std::string& format) { - std::string data; - dmlc::MemoryStringStream writer(&data); - dmlc::SeekStream* strm = &writer; - SaveToBinary(strm); - SaveBinaryToFile(path, data); + tvm::runtime::SimpleBinaryFileStream stream(path, "wb"); + SaveToBinary(&stream); } TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable").set_body_typed(ExecutableLoadBinary); // Load module from module. Module ExecutableLoadFile(const std::string& file_name, const std::string& format) { - std::string data; - LoadBinaryFromFile(file_name, &data); - dmlc::MemoryStringStream reader(&data); - dmlc::Stream* strm = &reader; - auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(strm)); + tvm::runtime::SimpleBinaryFileStream stream(file_name, "rb"); + auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(&stream)); return exec; }