Skip to content

Commit

Permalink
trt8 try #1
Browse files Browse the repository at this point in the history
  • Loading branch information
vjsrinivas committed Jul 22, 2022
1 parent b343e1a commit 8dd4894
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 12 deletions.
19 changes: 10 additions & 9 deletions detr/calibrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <fstream>
#include <algorithm>
#include "common.hpp"
#include "macros.h"

//! \class Int8EntropyCalibrator2
//!
Expand All @@ -20,11 +21,11 @@ class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
const char* img_dir, const char* calib_table_name,
const char* input_blob_name, bool read_cache = true);

virtual ~Int8EntropyCalibrator2();
int getBatchSize() const override;
bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
const void* readCalibrationCache(size_t& length) override;
void writeCalibrationCache(const void* cache, size_t length) override;
virtual ~Int8EntropyCalibrator2() TRT_NOEXCEPT;
int getBatchSize() const TRT_NOEXCEPT override;
bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override;
const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override;
void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override;

private:
int batchsize_;
Expand Down Expand Up @@ -62,11 +63,11 @@ Int8EntropyCalibrator2::~Int8EntropyCalibrator2() {
CUDA_CHECK(cudaFree(device_input_));
}

int Int8EntropyCalibrator2::getBatchSize() const {
int Int8EntropyCalibrator2::getBatchSize() const TRT_NOEXCEPT {
return batchsize_;
}

bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) {
bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT {
if (img_idx_ + batchsize_ > static_cast<int>(img_files_.size())) {
return false;
}
Expand Down Expand Up @@ -97,7 +98,7 @@ bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int
return true;
}

const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) {
const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) TRT_NOEXCEPT {
std::cout << "reading calib cache: " << calib_table_name_ << std::endl;
calib_cache_.clear();
std::ifstream input(calib_table_name_, std::ios::binary);
Expand All @@ -109,7 +110,7 @@ const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) {
return length ? calib_cache_.data() : nullptr;
}

void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) {
void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT {
std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl;
std::ofstream output(calib_table_name_, std::ios::binary);
output.write(reinterpret_cast<const char*>(cache), length);
Expand Down
8 changes: 6 additions & 2 deletions detr/detr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,10 @@ int num_heads = 8
v_shuffle->setReshapeDimensions(Dims3{ -1, num_heads, head_dim });
v_shuffle->setSecondTranspose(Permutation{ 1, 0, 2 });

auto q_product_k = network->addMatrixMultiply(*q_shuffle->getOutput(0), false, *k_shuffle->getOutput(0), true);
// OLD parameter: input0, doTransposeOnInput0, input1, doTransposeOnInput1
//auto q_product_k = network->addMatrixMultiply(*q_shuffle->getOutput(0), false, *k_shuffle->getOutput(0), true);
//std::cout << "q vector dimension: " << *q_shuffle->getOutput(0).getDimensions() << std::endl;
auto q_product_k = network->addMatrixMultiply(*q_shuffle->getOutput(0), MatrixOperation::kNONE, *k_shuffle->getOutput(0), MatrixOperation::kTRANSPOSE);
assert(q_product_k);

// src_key_padding_mask are all false, so do nothing here
Expand All @@ -181,7 +184,8 @@ int num_heads = 8
assert(softmax);
softmax->setAxes(4);

auto attn_product_v = network->addMatrixMultiply(*softmax->getOutput(0), false, *v_shuffle->getOutput(0), false);
//auto attn_product_v = network->addMatrixMultiply(*softmax->getOutput(0), false, *v_shuffle->getOutput(0), false);
auto attn_product_v = network->addMatrixMultiply(*softmax->getOutput(0), MatrixOperation::kNONE, *v_shuffle->getOutput(0), MatrixOperation::kNONE);
assert(attn_product_v);

auto attn_shuffle = network->addShuffle(*attn_product_v->getOutput(0));
Expand Down
3 changes: 2 additions & 1 deletion detr/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <ostream>
#include <sstream>
#include <string>
#include "macros.h"

using Severity = nvinfer1::ILogger::Severity;

Expand Down Expand Up @@ -208,7 +209,7 @@ class Logger : public nvinfer1::ILogger {
//! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
//! inheritance from nvinfer1::ILogger
//!
void log(Severity severity, const char* msg) override {
void log(Severity severity, const char* msg) TRT_NOEXCEPT override {
LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
}

Expand Down
12 changes: 12 additions & 0 deletions detr/macros.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Compatibility macros that absorb signature differences between TensorRT
// major versions, selected by NV_TENSORRT_MAJOR.
//
//   TRT_NOEXCEPT      - `noexcept` when building against TensorRT >= 8,
//                       empty otherwise. Appended to overrides of TensorRT
//                       virtuals so the same source matches either API.
//   TRT_CONST_ENQUEUE - `const` when building against TensorRT >= 8, empty
//                       otherwise. NOTE(review): presumably meant for plugin
//                       enqueue() overrides; confirm against its users.
#ifndef DETR_MACROS_H_
#define DETR_MACROS_H_

// NV_TENSORRT_MAJOR is provided by TensorRT's version header. Pull it in here
// when it is available so the outcome does not depend on include order: an
// undefined macro evaluates to 0 inside #if, which would silently select the
// pre-8 branch even when building against TensorRT 8.
#if !defined(NV_TENSORRT_MAJOR) && defined(__has_include)
#if __has_include(<NvInferVersion.h>)
#include <NvInferVersion.h>
#endif
#endif

#if defined(NV_TENSORRT_MAJOR) && NV_TENSORRT_MAJOR >= 8
#define TRT_NOEXCEPT noexcept
#define TRT_CONST_ENQUEUE const
#else
#define TRT_NOEXCEPT
#define TRT_CONST_ENQUEUE
#endif

#endif  // DETR_MACROS_H_

0 comments on commit 8dd4894

Please sign in to comment.