diff --git a/src/fully-connected-output.c b/src/fully-connected-output.c index b3a395bf..250ef5dd 100644 --- a/src/fully-connected-output.c +++ b/src/fully-connected-output.c @@ -34,7 +34,7 @@ static void pack_input_matrix( for (size_t outer_subblock_offset = 0; outer_subblock_offset < outer_subblock_size; outer_subblock_offset += 1) { const size_t index = (outer_block_start + outer_subblock_start + outer_subblock_offset) * input_channels + input_channel; const size_t packed_index = outer_block_start * input_channels + input_channels_block_start * outer_block_size + - outer_subblock_start * input_channels_block_size + input_channels_block_offset * outer_subblock_max + outer_subblock_offset; + outer_subblock_start * input_channels_block_size + input_channels_block_offset * outer_subblock_size + outer_subblock_offset; packed_matrix[packed_index] = matrix[index]; } } @@ -116,15 +116,12 @@ static void compute_matrix_multiplication( const nnp_fast_sgemm_function fast_sgemm = context->fast_sgemm_function; const nnp_full_sgemm_function full_sgemm = context->full_sgemm_function; - const size_t batch_block_stride = round_up(batch_block_size, batch_subblock_max); - const size_t output_channels_block_stride = round_up(output_channels_block_size, output_channels_subblock_max); - for (size_t output_channels_subblock_start = 0; output_channels_subblock_start < output_channels_block_size; output_channels_subblock_start += output_channels_subblock_max) { const size_t output_channels_subblock_size = min(output_channels_block_size - output_channels_subblock_start, output_channels_subblock_max); if ((batch_subblock_size == batch_subblock_max) && (output_channels_subblock_size == output_channels_subblock_max)) { fast_sgemm( input_channels_block_size, input_channels_block_start, - &input[batch_block_start * input_channels + input_channels_block_start * batch_block_stride + batch_subblock_start * input_channels_block_size], + &input[batch_block_start * input_channels + input_channels_block_start * batch_block_size + batch_subblock_start * input_channels_block_size], &kernel[(output_channels_block_start + output_channels_subblock_start) * input_channels_block_size], &output[(batch_block_start + batch_subblock_start) * output_channels + (output_channels_block_start + output_channels_subblock_start)], output_channels); @@ -132,7 +129,7 @@ static void compute_matrix_multiplication( full_sgemm( batch_subblock_size, output_channels_subblock_size, input_channels_block_size, input_channels_block_start, - &input[batch_block_start * input_channels + input_channels_block_start * batch_block_stride + batch_subblock_start * input_channels_block_size], + &input[batch_block_start * input_channels + input_channels_block_start * batch_block_size + batch_subblock_start * input_channels_block_size], &kernel[(output_channels_block_start + output_channels_subblock_start) * input_channels_block_size], &output[(batch_block_start + batch_subblock_start) * output_channels + (output_channels_block_start + output_channels_subblock_start)], output_channels);