Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load kmodel from stream #982

Merged
merged 18 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/Native/include/nncase/runtime/char_array_buffer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <gsl/gsl-lite.hpp>
#include <iostream>

namespace nncase {
class char_array_buffer : public std::streambuf {
public:
char_array_buffer(gsl::span<const char> data)
: begin_(data.begin()), end_(data.end()), current_(data.data()) {}

private:
int_type underflow() {
if (current_ == end_)
return traits_type::eof();

return traits_type::to_int_type(*current_);
}

int_type uflow() {
if (current_ == end_)
return traits_type::eof();

return traits_type::to_int_type(*current_++);
}

int_type pbackfail(int_type ch) {
if (current_ == begin_ ||
(ch != traits_type::eof() && ch != current_[-1]))
return traits_type::eof();

return traits_type::to_int_type(*--current_);
}

std::streamsize showmanyc() {
assert(std::less_equal<const char *>()(current_, end_));
return end_ - current_;
}

std::streampos seekoff(std::streamoff off, std::ios_base::seekdir way,
[[maybe_unused]] std::ios_base::openmode which) {
if (way == std::ios_base::beg) {
current_ = begin_ + off;
} else if (way == std::ios_base::cur) {
current_ += off;
} else if (way == std::ios_base::end) {
current_ = end_ + off;
}

if (current_ < begin_ || current_ > end_)
return -1;

return current_ - begin_;
}

std::streampos seekpos(std::streampos sp,
[[maybe_unused]] std::ios_base::openmode which) {
current_ = begin_ + sp;

if (current_ < begin_ || current_ > end_)
return -1;

return current_ - begin_;
}

const char *const begin_;
const char *const end_;
const char *current_;
};
} // namespace nncase
6 changes: 5 additions & 1 deletion src/Native/include/nncase/runtime/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "runtime_module.h"
#include "runtime_tensor.h"
#include <gsl/gsl-lite.hpp>
#include <istream>
#include <memory>
#include <nncase/shape.h>
#include <nncase/tensor.h>
Expand Down Expand Up @@ -68,6 +69,8 @@ class NNCASE_API interpreter {
[[nodiscard]] result<void> load_model(gsl::span<const gsl::byte> buffer,
bool copy_buffer = false) noexcept;

[[nodiscard]] result<void> load_model(std::istream &stream) noexcept;

options_dict &options() noexcept;
result<runtime_module *> find_module_by_id(size_t index) noexcept;

Expand Down Expand Up @@ -102,9 +105,10 @@ class NNCASE_API interpreter {
tensor_type input_tensor_type(size_t index) const noexcept;
tensor_type output_tensor_type(size_t index) const noexcept;

result<void> initialize_model(const model_header &header) noexcept;

private:
std::shared_ptr<nncase::runtime::dump_manager> dump_manager_;
std::unique_ptr<gsl::byte[]> model_data_;
std::vector<std::unique_ptr<runtime_module>> modules_;
runtime_function *entry_function_;
options_dict options_;
Expand Down
10 changes: 8 additions & 2 deletions src/Native/include/nncase/runtime/runtime_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#pragma once
#include "model.h"
#include "result.h"
#include "runtime_section_context.h"
#include <nncase/runtime/stream_reader.h>
#include <nncase/type.h>
#include <nncase/value.h>

Expand All @@ -24,10 +26,10 @@ class interpreter;
class runtime_module;
struct runtime_module_init_context;

struct NNCASE_API runtime_function_init_context {
struct NNCASE_API runtime_function_init_context
: public runtime_section_context {
virtual runtime_module_init_context &module_init_context() noexcept = 0;
virtual const function_header &header() noexcept = 0;
virtual gsl::span<const gsl::byte> section(const char *name) noexcept = 0;
};

class NNCASE_API runtime_function {
Expand All @@ -40,6 +42,10 @@ class NNCASE_API runtime_function {
result<void>
initialize(gsl::span<const gsl::byte> payload,
runtime_module_init_context &module_init_context) noexcept;
result<void>
initialize(stream_reader &reader,
runtime_module_init_context &module_init_context) noexcept;

runtime_module &module() const noexcept;

uint32_t parameters_size() const noexcept;
Expand Down
9 changes: 6 additions & 3 deletions src/Native/include/nncase/runtime/runtime_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@
#include "model.h"
#include "result.h"
#include "runtime_function.h"
#include "runtime_section_context.h"
#include "span_reader.h"
#include "stream_reader.h"
#include <nncase/kernels/kernel_context.h>

BEGIN_NS_NNCASE_RUNTIME

class interpreter;

struct NNCASE_API runtime_module_init_context {
virtual bool is_section_pinned() const noexcept = 0;
struct NNCASE_API runtime_module_init_context : public runtime_section_context {
virtual interpreter &interp() noexcept = 0;
virtual const module_header &header() noexcept = 0;
virtual gsl::span<const gsl::byte> section(const char *name) noexcept = 0;
};

class NNCASE_API runtime_module {
Expand All @@ -49,6 +50,8 @@ class NNCASE_API runtime_module {

result<void> initialize(gsl::span<const gsl::byte> payload,
interpreter &interp) noexcept;
result<void> initialize(stream_reader &reader,
interpreter &interp) noexcept;
const module_kind_t &kind() const noexcept;

interpreter &interp() const noexcept { return *interp_; }
Expand Down
51 changes: 51 additions & 0 deletions src/Native/include/nncase/runtime/runtime_section_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "host_buffer.h"
#include "model.h"
#include "result.h"
#include "span_reader.h"
#include "stream_reader.h"
#include <nncase/type.h>
#include <nncase/value.h>

BEGIN_NS_NNCASE_RUNTIME

struct NNCASE_API runtime_section_context {
virtual bool is_section_pinned() const noexcept = 0;
virtual result<gsl::span<const gsl::byte>>
section(const char *name) noexcept = 0;
virtual result<stream_reader *>
seek_section(const char *name, section_header &header) noexcept = 0;

result<gsl::span<const gsl::byte>>
get_or_read_section(const char *name, host_buffer_t &storage,
bool allocate_shared) noexcept;

template <class TCallable>
result<void> read_section(const char *name, TCallable &&callable) noexcept {
auto section_span_r = section(name);
if (section_span_r.is_ok()) {
span_reader sr(std::move(section_span_r).unwrap());
return callable(sr, sr.avail());
} else {
section_header header;
try_var(sr, seek_section(name, header));
return callable(*sr, (size_t)header.body_size);
}
}
};

END_NS_NNCASE_RUNTIME
75 changes: 75 additions & 0 deletions src/Native/include/nncase/runtime/stream_reader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstring>
#include <gsl/gsl-lite.hpp>
#include <istream>
#include <iterator>
#include <nncase/compiler_defs.h>
#include <nncase/runtime/dbg.h>
#include <string>
#include <vector>

BEGIN_NS_NNCASE_RUNTIME

class stream_reader {
public:
stream_reader(std::istream &stream) : stream_(stream) {}

std::streampos tell() const noexcept { return stream_.tellg(); }
bool empty() const noexcept { return !stream_.eof(); }

void seek(std::streampos pos) noexcept { stream_.seekg(pos); }

template <class T> T read() {
T value;
read(value);
return value;
}

template <class T> T read_unaligned() { return read<T>(); }

template <class T> T peek() {
T value;
auto pos = tell();
read(value);
seek(pos);
return value;
}

template <class T> T peek_unaligned() { return peek<T>(); }

template <class T> void read(T &value) {
stream_.read(reinterpret_cast<char *>(&value), sizeof(value));
}

template <class T> void read_span(gsl::span<T> span) {
size_t sub_data_size = 8388608;
for (size_t pos = 0; pos < span.size_bytes();) {
if (pos + sub_data_size >= span.size_bytes())
sub_data_size = span.size_bytes() - pos;
stream_.read(reinterpret_cast<char *>(span.data()) + pos,
sub_data_size);
pos += sub_data_size;
}
}

void skip(size_t count) { stream_.seekg(count, std::ios::cur); }

private:
std::istream &stream_;
};

END_NS_NNCASE_RUNTIME
4 changes: 4 additions & 0 deletions src/Native/include/nncase/runtime/type_serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once
#include <nncase/runtime/result.h>
#include <nncase/runtime/span_reader.h>
#include <nncase/runtime/stream_reader.h>
#include <nncase/type.h>

BEGIN_NS_NNCASE_RUNTIME
Expand All @@ -31,4 +32,7 @@ typedef enum : uint8_t {
result<type> deserialize_type(span_reader &sr) noexcept;
result<datatype_t> deserialize_datatype(span_reader &sr) noexcept;

result<type> deserialize_type(stream_reader &sr) noexcept;
result<datatype_t> deserialize_datatype(stream_reader &sr) noexcept;

END_NS_NNCASE_RUNTIME
1 change: 1 addition & 0 deletions src/Native/src/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set(SRCS buffer.cpp
host_buffer.cpp
host_runtime_tensor.cpp
interpreter.cpp
runtime_section_context.cpp
runtime_loader.cpp
runtime_module.cpp
runtime_function.cpp
Expand Down
12 changes: 9 additions & 3 deletions src/Native/src/runtime/host_allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,24 @@ class host_buffer_allocator : public buffer_allocator {
auto data = new (std::nothrow) gsl::byte[bytes];
if (!data)
return err(std::errc::not_enough_memory);
auto paddr =
options.flags & HOST_BUFFER_ALLOCATE_SHARED ? (uintptr_t)data : 0;
return ok<buffer_t>(object_t<host_buffer_impl>(
std::in_place, data, bytes, [](gsl::byte *p) { delete[] p; }, 0,
std::in_place, data, bytes, [](gsl::byte *p) { delete[] p; }, paddr,
*this, host_sync_status_t::valid, true));
}

result<buffer_t>
attach([[maybe_unused]] gsl::span<gsl::byte> data,
[[maybe_unused]] const buffer_attach_options &options) override {
auto paddr = options.flags & HOST_BUFFER_ATTACH_SHARED
? (options.physical_address ? options.physical_address
: (uintptr_t)data.data())
: 0;
return ok<buffer_t>(object_t<host_buffer_impl>(
std::in_place, data.data(), data.size_bytes(),
[]([[maybe_unused]] gsl::byte *p) {}, options.physical_address,
*this, host_sync_status_t::valid));
[]([[maybe_unused]] gsl::byte *p) {}, paddr, *this,
host_sync_status_t::valid));
}
};

Expand Down
Loading