
Commit bd66ba1
wip
nihui committed Apr 29, 2024
1 parent 2a6fd83 commit bd66ba1
Showing 2 changed files with 31 additions and 26 deletions.
21 changes: 13 additions & 8 deletions src/layer/arm/lstm_arm.cpp
@@ -1349,35 +1349,40 @@ int LSTM_arm::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat
 
     Mat hidden;
     Mat cell;
-    Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator;
-    if (bottom_blobs.size() == 2)
+    Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator;
+    if (bottom_blobs.size() == 3)
     {
         if (elemtype == 1)
         {
-            hidden = bottom_blobs[1].clone(hidden_allocator);
-            cell = bottom_blobs[2].clone(hidden_allocator);
+            hidden = bottom_blobs[1].clone(hidden_cell_allocator);
+            cell = bottom_blobs[2].clone(hidden_cell_allocator);
         }
         if (elemtype == 2)
         {
             Option opt_cast = opt;
-            opt_cast.blob_allocator = hidden_allocator;
+            opt_cast.blob_allocator = hidden_cell_allocator;
             cast_float16_to_float32(bottom_blobs[1], hidden, opt_cast);
             cast_float16_to_float32(bottom_blobs[2], cell, opt_cast);
         }
         if (elemtype == 4)
         {
             Option opt_cast = opt;
-            opt_cast.blob_allocator = hidden_allocator;
+            opt_cast.blob_allocator = hidden_cell_allocator;
             cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt_cast);
             cast_bfloat16_to_float32(bottom_blobs[2], cell, opt_cast);
         }
     }
     else
     {
-        hidden.create(num_output, num_directions, 4u, hidden_allocator);
+        hidden.create(num_output, num_directions, 4u, hidden_cell_allocator);
         if (hidden.empty())
             return -100;
         hidden.fill(0.f);
+
+        cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator);
+        if (cell.empty())
+            return -100;
+        cell.fill(0.f);
     }
 
     Mat& top_blob = top_blobs[0];
@@ -1435,7 +1440,7 @@ int LSTM_arm::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat
         }
     }
 
-    if (top_blobs.size() == 2)
+    if (top_blobs.size() == 3)
     {
         if (elemtype == 1)
         {
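For reference, the I/O convention this hunk settles on: with three bottom blobs the layer reads blob 1 as the initial hidden state and blob 2 as the initial cell state, and with three top blobs it writes the final hidden and cell states back out; the hidden state is num_output wide per direction while the cell state is hidden_size wide per direction (the two typically differ only when an output projection is used). Below is a minimal standalone sketch of that initial-state setup using plain std::vector instead of ncnn::Mat; the struct and function names are invented for illustration and are not part of ncnn.

    #include <cstddef>
    #include <vector>

    // Hypothetical container for the LSTM recurrent state, mirroring the shapes
    // in the diff above: hidden is num_output wide, cell is hidden_size wide,
    // one row per direction.
    struct LstmState
    {
        std::vector<float> hidden; // num_directions * num_output
        std::vector<float> cell;   // num_directions * hidden_size
    };

    // Build the initial state. If the caller supplied initial_hidden/initial_cell
    // (the optional bottom blobs 1 and 2), copy them; otherwise zero-initialize
    // both tensors, as the else branch of the patched code does.
    static LstmState make_initial_state(int num_directions, int num_output, int hidden_size,
                                        const std::vector<float>* initial_hidden = nullptr,
                                        const std::vector<float>* initial_cell = nullptr)
    {
        LstmState state;
        if (initial_hidden && initial_cell)
        {
            state.hidden = *initial_hidden;
            state.cell = *initial_cell;
        }
        else
        {
            state.hidden.assign((size_t)num_directions * num_output, 0.f);
            state.cell.assign((size_t)num_directions * hidden_size, 0.f);
        }
        return state;
    }
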
36 changes: 18 additions & 18 deletions src/layer/arm/lstm_int8.h
@@ -55,7 +55,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
         float* bias_c_IFOG = bias_c_tm_dr.row(0);
 
         int q = 0;
-        for (; q < num_output; q++)
+        for (; q < hidden_size; q++)
         {
             bias_c_IFOG[0] = bias_c_I[q];
             bias_c_IFOG[1] = bias_c_F[q];
@@ -64,15 +64,15 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
 
             bias_c_IFOG += 4;
 
-            const signed char* weight_xc_I = weight_xc_dr.row<const signed char>(num_output * 0 + q);
-            const signed char* weight_xc_F = weight_xc_dr.row<const signed char>(num_output * 1 + q);
-            const signed char* weight_xc_O = weight_xc_dr.row<const signed char>(num_output * 2 + q);
-            const signed char* weight_xc_G = weight_xc_dr.row<const signed char>(num_output * 3 + q);
+            const signed char* weight_xc_I = weight_xc_dr.row<const signed char>(hidden_size * 0 + q);
+            const signed char* weight_xc_F = weight_xc_dr.row<const signed char>(hidden_size * 1 + q);
+            const signed char* weight_xc_O = weight_xc_dr.row<const signed char>(hidden_size * 2 + q);
+            const signed char* weight_xc_G = weight_xc_dr.row<const signed char>(hidden_size * 3 + q);
 
-            const signed char* weight_hc_I = weight_hc_dr.row<const signed char>(num_output * 0 + q);
-            const signed char* weight_hc_F = weight_hc_dr.row<const signed char>(num_output * 1 + q);
-            const signed char* weight_hc_O = weight_hc_dr.row<const signed char>(num_output * 2 + q);
-            const signed char* weight_hc_G = weight_hc_dr.row<const signed char>(num_output * 3 + q);
+            const signed char* weight_hc_I = weight_hc_dr.row<const signed char>(hidden_size * 0 + q);
+            const signed char* weight_hc_F = weight_hc_dr.row<const signed char>(hidden_size * 1 + q);
+            const signed char* weight_hc_O = weight_hc_dr.row<const signed char>(hidden_size * 2 + q);
+            const signed char* weight_hc_G = weight_hc_dr.row<const signed char>(hidden_size * 3 + q);
 
             signed char* kptr = weight_data_tm_dr.row<signed char>(q);
             float* descales_ptr = weight_data_tm_int8_descales_dr.row(q);
@@ -95,14 +95,14 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
                 kptr += 4;
             }
 
-            descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[num_output * 0 + q];
-            descales_ptr[1] = 1.f / weight_xc_int8_scales_ptr[num_output * 1 + q];
-            descales_ptr[2] = 1.f / weight_xc_int8_scales_ptr[num_output * 2 + q];
-            descales_ptr[3] = 1.f / weight_xc_int8_scales_ptr[num_output * 3 + q];
-            descales_ptr[4] = 1.f / weight_hc_int8_scales_ptr[num_output * 0 + q];
-            descales_ptr[5] = 1.f / weight_hc_int8_scales_ptr[num_output * 1 + q];
-            descales_ptr[6] = 1.f / weight_hc_int8_scales_ptr[num_output * 2 + q];
-            descales_ptr[7] = 1.f / weight_hc_int8_scales_ptr[num_output * 3 + q];
+            descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q];
+            descales_ptr[1] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q];
+            descales_ptr[2] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q];
+            descales_ptr[3] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 3 + q];
+            descales_ptr[4] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q];
+            descales_ptr[5] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q];
+            descales_ptr[6] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q];
+            descales_ptr[7] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q];
         }
     }
 }
@@ -124,7 +124,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
     int T = bottom_blob_int8.h;
 
     int num_output = top_blob.w;
-    int hidden_size = hidden_state.w;
+    int hidden_size = cell_state.w;
 
     // 4 x hidden_size
     Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
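The indexing fixed here reflects how these weights are laid out: weight_xc and weight_hc each hold 4 * hidden_size rows per direction, grouped gate by gate (all I rows, then F, O, G), so gate g of cell q sits at row hidden_size * g + q; the transform repacks those rows IFOG-interleaved and stores the reciprocal int8 scales as descales. Below is a minimal standalone sketch of that repacking with plain vectors; the names and the exact interleave granularity are illustrative rather than the kernel's actual memory layout.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Pack the four per-gate rows of one cell q into IFOG-interleaved order and
    // compute the descale (1/scale) for each gate, mirroring the row indexing in
    // lstm_transform_weight_int8. Sizes and names are illustrative.
    static void pack_ifog_row(const std::vector<int8_t>& weight, // 4 * hidden_size rows, each row_len wide
                              const std::vector<float>& scales,  // 4 * hidden_size per-row int8 scales
                              int hidden_size, int row_len, int q,
                              std::vector<int8_t>& packed,       // out: row_len * 4 interleaved weights
                              float descales[4])                 // out: per-gate descales
    {
        const int8_t* rows[4];
        for (int g = 0; g < 4; g++) // g: 0=I, 1=F, 2=O, 3=G
        {
            rows[g] = weight.data() + (size_t)(hidden_size * g + q) * row_len;
            descales[g] = 1.f / scales[hidden_size * g + q];
        }

        packed.resize((size_t)row_len * 4);
        for (int i = 0; i < row_len; i++)
        {
            packed[i * 4 + 0] = rows[0][i];
            packed[i * 4 + 1] = rows[1][i];
            packed[i * 4 + 2] = rows[2][i];
            packed[i * 4 + 3] = rows[3][i];
        }
    }

Indexing these rows with num_output is only equivalent when hidden_size == num_output, which is why the transform, and the hidden_size derivation in lstm_int8, are switched to the cell dimension.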
