Skip to content

Commit

Permalink
Fully connected output: fix bug when batch is not divisible by 4. Close
Browse files Browse the repository at this point in the history
  • Loading branch information
Maratyszcza committed Mar 22, 2017
1 parent 91c3d15 commit d64ddf0
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/fully-connected-output.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ static void pack_input_matrix(
for (size_t outer_subblock_offset = 0; outer_subblock_offset < outer_subblock_size; outer_subblock_offset += 1) {
const size_t index = (outer_block_start + outer_subblock_start + outer_subblock_offset) * input_channels + input_channel;
const size_t packed_index = outer_block_start * input_channels + input_channels_block_start * outer_block_size +
outer_subblock_start * input_channels_block_size + input_channels_block_offset * outer_subblock_max + outer_subblock_offset;
outer_subblock_start * input_channels_block_size + input_channels_block_offset * outer_subblock_size + outer_subblock_offset;
packed_matrix[packed_index] = matrix[index];
}
}
Expand Down Expand Up @@ -116,23 +116,20 @@ static void compute_matrix_multiplication(
const nnp_fast_sgemm_function fast_sgemm = context->fast_sgemm_function;
const nnp_full_sgemm_function full_sgemm = context->full_sgemm_function;

const size_t batch_block_stride = round_up(batch_block_size, batch_subblock_max);
const size_t output_channels_block_stride = round_up(output_channels_block_size, output_channels_subblock_max);

for (size_t output_channels_subblock_start = 0; output_channels_subblock_start < output_channels_block_size; output_channels_subblock_start += output_channels_subblock_max) {
const size_t output_channels_subblock_size = min(output_channels_block_size - output_channels_subblock_start, output_channels_subblock_max);
if ((batch_subblock_size == batch_subblock_max) && (output_channels_subblock_size == output_channels_subblock_max)) {
fast_sgemm(
input_channels_block_size, input_channels_block_start,
&input[batch_block_start * input_channels + input_channels_block_start * batch_block_stride + batch_subblock_start * input_channels_block_size],
&input[batch_block_start * input_channels + input_channels_block_start * batch_block_size + batch_subblock_start * input_channels_block_size],
&kernel[(output_channels_block_start + output_channels_subblock_start) * input_channels_block_size],
&output[(batch_block_start + batch_subblock_start) * output_channels + (output_channels_block_start + output_channels_subblock_start)],
output_channels);
} else {
full_sgemm(
batch_subblock_size, output_channels_subblock_size,
input_channels_block_size, input_channels_block_start,
&input[batch_block_start * input_channels + input_channels_block_start * batch_block_stride + batch_subblock_start * input_channels_block_size],
&input[batch_block_start * input_channels + input_channels_block_start * batch_block_size + batch_subblock_start * input_channels_block_size],
&kernel[(output_channels_block_start + output_channels_subblock_start) * input_channels_block_size],
&output[(batch_block_start + batch_subblock_start) * output_channels + (output_channels_block_start + output_channels_subblock_start)],
output_channels);
Expand Down

0 comments on commit d64ddf0

Please sign in to comment.