Skip to content

Commit

Permalink
ukernel test improvements (#15542)
Browse files Browse the repository at this point in the history
* Consistently compare with/without skipping of intermediate roundings.
A catch is that the ukernel may fall back to a generic code path (and
that fallback is consistently exercised by the test, even when a
non-fallback path is also available and tested). And generic code paths
("tile functions") never skipped intermediate roundings, even if allowed
to by the flag. This caused complicated test code retrying again on
error. This PR simply adds the skipping-intermediate-roundings generic
tile functions, so the test code is simpler, and concretely I just
needed that for #15543 as I'm adding bf16-accumulator tile functions
that are skipping intermediate roundings.
* I had to also update `iree-e2e-matmul-test` to switch to skipping
intermediate roundings. Unlike the ukernels' own tests, which really
must test both flavors, in `iree-e2e-matmul-test` we are e2e testing
what the compiler produces, and that is skippig intermediate roundings
at least by default, and while that could be overridden with
`--iree-llvmcpu-skip-intermediate-roundings=false`, we don't currently
test that in e2e matmul tests.
* Generate better random test input values. Some were too large - when
we generate random bfloat16 to accumulate into bfloat16, they better be
very small as we don't want to grow accumulators to the point where they
would start rounding. It's OK, because bfloat16 kernels use bfloat16
arithmetic instructions, not bit hacks, so correctness is sufficiently
tested on very small values. Conversely, for int8/int16 test input
values, we were generating a very narrow range and that was potentially
missing important coverage as some of our int kernels are starting to do
evil bit hacks (#15525).
  • Loading branch information
bjacob authored Nov 14, 2023
1 parent ef0f1a4 commit 16e4346
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 50 deletions.
90 changes: 86 additions & 4 deletions runtime/src/iree/builtins/ukernel/mmt4d_tile.c
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ static void iree_uk_mmt4d_tile_f16f16f32_generic(
}

// Generic implementation of matmul tile, f16*f16->f16 case.
static void iree_uk_mmt4d_tile_f16f16f16_generic(
// Not skipping intermediate roundings.
static void iree_uk_mmt4d_tile_f16f16f16_generic_noskipround(
void* out_tile_untyped, const void* lhs_panel_untyped,
const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
iree_uk_int16_t* out_tile = out_tile_untyped;
Expand Down Expand Up @@ -226,6 +227,44 @@ static void iree_uk_mmt4d_tile_f16f16f16_generic(
for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
}

// Generic implementation of matmul tile, f16*f16->f16 case.
// Skipping intermediate roundings.
static void iree_uk_mmt4d_tile_f16f16f16_generic_skipround(
void* out_tile_untyped, const void* lhs_panel_untyped,
const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
iree_uk_int16_t* out_tile = out_tile_untyped;
const iree_uk_uint16_t* lhs_panel = lhs_panel_untyped;
const iree_uk_uint16_t* rhs_panel = rhs_panel_untyped;
iree_uk_int16_t M0 = params->M0;
iree_uk_int16_t N0 = params->N0;
iree_uk_int16_t K0 = params->K0;
// Initialize the local accumulator tile.
float acc_f32[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
for (int i = 0; i < M0 * N0; ++i)
acc_f32[i] = iree_uk_f16_to_f32(out_tile[i]);
} else {
for (int i = 0; i < M0 * N0; ++i) acc_f32[i] = 0;
}
// Accumulation loop.
for (iree_uk_index_t k = 0; k < params->K; ++k) {
for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
float lhs_f32 = iree_uk_f16_to_f32(lhs_panel[i0 * K0 + k0]);
float rhs_f32 = iree_uk_f16_to_f32(rhs_panel[j0 * K0 + k0]);
acc_f32[i0 * N0 + j0] += lhs_f32 * rhs_f32;
}
}
}
lhs_panel += M0 * K0;
rhs_panel += N0 * K0;
}
// Store the local accumulator tile to the destination.
for (int i = 0; i < M0 * N0; ++i)
out_tile[i] = iree_uk_f32_to_f16(acc_f32[i]);
}

// Generic implementation of matmul tile, bf16*bf16->f32 case.
static void iree_uk_mmt4d_tile_bf16bf16f32_generic(
void* out_tile_untyped, const void* lhs_panel_untyped,
Expand Down Expand Up @@ -262,7 +301,8 @@ static void iree_uk_mmt4d_tile_bf16bf16f32_generic(
}

// Generic implementation of matmul tile, bf16*bf16->bf16 case.
static void iree_uk_mmt4d_tile_bf16bf16bf16_generic(
// Not skipping intermediate roundings.
static void iree_uk_mmt4d_tile_bf16bf16bf16_generic_noskipround(
void* out_tile_untyped, const void* lhs_panel_untyped,
const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
iree_uk_int16_t* out_tile = out_tile_untyped;
Expand Down Expand Up @@ -298,6 +338,44 @@ static void iree_uk_mmt4d_tile_bf16bf16bf16_generic(
for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
}

// Generic implementation of matmul tile, bf16*bf16->bf16 case.
// Skipping intermediate roundings.
static void iree_uk_mmt4d_tile_bf16bf16bf16_generic_skipround(
void* out_tile_untyped, const void* lhs_panel_untyped,
const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
iree_uk_int16_t* out_tile = out_tile_untyped;
const iree_uk_uint16_t* lhs_panel = lhs_panel_untyped;
const iree_uk_uint16_t* rhs_panel = rhs_panel_untyped;
iree_uk_int16_t M0 = params->M0;
iree_uk_int16_t N0 = params->N0;
iree_uk_int16_t K0 = params->K0;
// Initialize the local accumulator tile.
float acc_f32[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
for (int i = 0; i < M0 * N0; ++i)
acc_f32[i] = iree_uk_bf16_to_f32(out_tile[i]);
} else {
for (int i = 0; i < M0 * N0; ++i) acc_f32[i] = 0;
}
// Accumulation loop.
for (iree_uk_index_t k = 0; k < params->K; ++k) {
for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
float lhs_f32 = iree_uk_bf16_to_f32(lhs_panel[i0 * K0 + k0]);
float rhs_f32 = iree_uk_bf16_to_f32(rhs_panel[j0 * K0 + k0]);
acc_f32[i0 * N0 + j0] += lhs_f32 * rhs_f32;
}
}
}
lhs_panel += M0 * K0;
rhs_panel += N0 * K0;
}
// Store the local accumulator tile to the destination.
for (int i = 0; i < M0 * N0; ++i)
out_tile[i] = iree_uk_f32_to_bf16(acc_f32[i]);
}

static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic(
const iree_uk_mmt4d_params_t* params) {
switch (iree_uk_mmt4d_type(params->flags)) {
Expand All @@ -312,11 +390,15 @@ static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic(
case iree_uk_mmt4d_type_f16f16f32:
return iree_uk_mmt4d_tile_f16f16f32_generic;
case iree_uk_mmt4d_type_f16f16f16:
return iree_uk_mmt4d_tile_f16f16f16_generic;
return (params->flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS)
? iree_uk_mmt4d_tile_f16f16f16_generic_skipround
: iree_uk_mmt4d_tile_f16f16f16_generic_noskipround;
case iree_uk_mmt4d_type_bf16bf16f32:
return iree_uk_mmt4d_tile_bf16bf16f32_generic;
case iree_uk_mmt4d_type_bf16bf16bf16:
return iree_uk_mmt4d_tile_bf16bf16bf16_generic;
return (params->flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS)
? iree_uk_mmt4d_tile_bf16bf16bf16_generic_skipround
: iree_uk_mmt4d_tile_bf16bf16bf16_generic_noskipround;
default:
// shouldn't happen, validated earlier.
IREE_UK_ASSUME_UNREACHABLE;
Expand Down
57 changes: 40 additions & 17 deletions runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ static void iree_mmt4d_reference_innerloop_bf16bf16f32(
*out_ptr = acc;
}

static void iree_mmt4d_reference_innerloop_bf16bf16bf16(
static void iree_mmt4d_reference_innerloop_bf16bf16bf16_noskipround(
uint16_t* out_ptr, const uint16_t* lhs_ptr, const uint16_t* rhs_ptr,
const iree_uk_mmt4d_params_t* params) {
uint16_t acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0;
Expand All @@ -123,6 +123,36 @@ static void iree_mmt4d_reference_innerloop_bf16bf16bf16(
*out_ptr = acc;
}

static void iree_mmt4d_reference_innerloop_bf16bf16bf16_skipround(
uint16_t* out_ptr, const uint16_t* lhs_ptr, const uint16_t* rhs_ptr,
const iree_uk_mmt4d_params_t* params) {
float acc_f32 = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE
? iree_math_bf16_to_f32(*out_ptr)
: 0.f;
for (iree_uk_index_t k = 0; k < params->K; ++k) {
for (iree_uk_index_t k0 = 0; k0 < params->K0; ++k0) {
float lhs_f32 =
iree_math_bf16_to_f32(lhs_ptr[k * params->M0 * params->K0 + k0]);
float rhs_f32 =
iree_math_bf16_to_f32(rhs_ptr[k * params->N0 * params->K0 + k0]);
acc_f32 += lhs_f32 * rhs_f32;
}
}
*out_ptr = iree_math_f32_to_bf16(acc_f32);
}

static void iree_mmt4d_reference_innerloop_bf16bf16bf16(
uint16_t* out_ptr, const uint16_t* lhs_ptr, const uint16_t* rhs_ptr,
const iree_uk_mmt4d_params_t* params) {
if (params->flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS) {
iree_mmt4d_reference_innerloop_bf16bf16bf16_skipround(out_ptr, lhs_ptr,
rhs_ptr, params);
} else {
iree_mmt4d_reference_innerloop_bf16bf16bf16_noskipround(out_ptr, lhs_ptr,
rhs_ptr, params);
}
}

static void iree_mmt4d_reference_innerloop_s8s8s32(
int32_t* out_ptr, const int8_t* lhs_ptr, const int8_t* rhs_ptr,
const iree_uk_mmt4d_params_t* params) {
Expand Down Expand Up @@ -346,22 +376,10 @@ static void iree_uk_test_mmt4d_for_shape_params(
// code accumulates in a different order compared to the actual code. This
// relies on picking input test matrix elements so that all intermediate
// values are exactly representable - i.e. small integer numerators.
// This also relies on honoring IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS
// consistently between actual tile functions (including generic fallback
// ones) and the reference code in this test.
bool fail = memcmp(actual_out_buffer, reference_out_buffer, out_buffer_size);
if (fail) {
// The one thing that causes legitimate bit differences at the moment is
// when we enable skipping intermediate roundings but the actual kernel does
// not skip intermediate roundings, such as when falling back on a generic
// code path. In that case, we retry with reference code not skipping
// intermediate roundings. This currently only happens when the output type
// is f16.
if (out_type == IREE_UK_TYPE_FLOAT_16 &&
(params.flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS)) {
reference_params.flags &= ~IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS;
memcpy(reference_out_buffer, init_out_buffer, out_buffer_size);
iree_mmt4d_reference(&reference_params);
fail = memcmp(actual_out_buffer, reference_out_buffer, out_buffer_size);
}
}

if (fail) {
IREE_UK_TEST_FAIL(test);
Expand All @@ -379,6 +397,11 @@ static void iree_uk_test_mmt4d_for_tile_params(iree_uk_test_t* test,
typedef struct shape_mnk_t {
int m, n, k;
} shape_mnk_t;
const iree_uk_mmt4d_type_t mmt4d_type =
iree_uk_mmt4d_type(((const iree_uk_mmt4d_params_t*)src_params)->flags);
const iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
const int max_reduction_size =
(out_type == IREE_UK_TYPE_BFLOAT_16) ? 100 : 1000;
const shape_mnk_t shapes[] = {
// Degenerate case M==0. Vacuous.
{0, 1, 1},
Expand All @@ -394,7 +417,7 @@ static void iree_uk_test_mmt4d_for_tile_params(iree_uk_test_t* test,
{1, 1, 1},
{1, 1, 2},
{1, 1, 10},
{1, 1, 1000},
{1, 1, max_reduction_size},
{2, 1, 1},
{1, 2, 1},
{2, 2, 2},
Expand Down
27 changes: 12 additions & 15 deletions runtime/src/iree/builtins/ukernel/tools/util.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,6 @@ int iree_uk_random_engine_get_0_1(iree_uk_random_engine_t* e) {
return v & 1;
}

int iree_uk_random_engine_get_minus16_plus15(iree_uk_random_engine_t* e) {
int v = iree_uk_random_engine_get_0_65535(e);
return (v % 32) - 16;
}

void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes,
iree_uk_type_t type,
iree_uk_random_engine_t* engine) {
Expand Down Expand Up @@ -125,34 +120,36 @@ void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes,
// Small integers, should work for now for all the types we currently have
// and enable exact float arithmetic, allowing to keep tests simpler for
// now. Watch out for when we'll do float16!
int random_val = iree_uk_random_engine_get_minus16_plus15(engine);
int random_val = iree_uk_random_engine_get_0_65535(engine);
switch (type) {
case IREE_UK_TYPE_FLOAT_32:
((float*)buffer)[i] = random_val;
((float*)buffer)[i] = (random_val % 4) - 2;
break;
case IREE_UK_TYPE_FLOAT_16:
((uint16_t*)buffer)[i] = iree_math_f32_to_f16((float)random_val);
((uint16_t*)buffer)[i] =
iree_math_f32_to_f16((float)((random_val % 16) - 8));
break;
case IREE_UK_TYPE_BFLOAT_16:
((uint16_t*)buffer)[i] = iree_math_f32_to_bf16((float)random_val);
((uint16_t*)buffer)[i] =
iree_math_f32_to_bf16((float)((random_val % 4) - 2));
break;
case IREE_UK_TYPE_SINT_32:
((int32_t*)buffer)[i] = random_val;
((int32_t*)buffer)[i] = (random_val % 2048) - 512;
break;
case IREE_UK_TYPE_UINT_32:
((uint32_t*)buffer)[i] = random_val;
((uint32_t*)buffer)[i] = random_val % 2048;
break;
case IREE_UK_TYPE_SINT_16:
((int16_t*)buffer)[i] = random_val;
((int16_t*)buffer)[i] = (random_val % 2048) - 512;
break;
case IREE_UK_TYPE_UINT_16:
((uint16_t*)buffer)[i] = random_val;
((uint16_t*)buffer)[i] = random_val % 2048;
break;
case IREE_UK_TYPE_SINT_8:
((int8_t*)buffer)[i] = random_val;
((int8_t*)buffer)[i] = (random_val % 256) - 128;
break;
case IREE_UK_TYPE_UINT_8:
((uint8_t*)buffer)[i] = random_val;
((uint8_t*)buffer)[i] = random_val % 256;
break;
default:
IREE_UK_ASSERT(false && "unknown type");
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/iree/builtins/ukernel/tools/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ static inline iree_uk_random_engine_t iree_uk_random_engine_init(void) {
iree_uk_uint32_t iree_uk_random_engine_get_uint32(iree_uk_random_engine_t* e);
iree_uk_uint64_t iree_uk_random_engine_get_uint64(iree_uk_random_engine_t* e);
int iree_uk_random_engine_get_0_65535(iree_uk_random_engine_t* e);
int iree_uk_random_engine_get_0_255(iree_uk_random_engine_t* e);
int iree_uk_random_engine_get_0_1(iree_uk_random_engine_t* e);
int iree_uk_random_engine_get_minus16_plus15(iree_uk_random_engine_t* e);
void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes,
iree_uk_type_t type,
iree_uk_random_engine_t* engine);
Expand Down
17 changes: 4 additions & 13 deletions tools/iree-e2e-matmul-test.c
Original file line number Diff line number Diff line change
Expand Up @@ -425,11 +425,8 @@ static void reference_matmul_f16_f16_f16_f16(
iree_hal_dim_t m, iree_hal_dim_t n) {
float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0.f;
for (iree_hal_dim_t k = 0; k < k_size; ++k) {
acc = iree_math_round_to_nearest_f16(
iree_math_round_to_nearest_f16(
(iree_math_f16_to_f32(lhs_data[k + m * k_size]) *
iree_math_f16_to_f32(rhs_data[n + k * n_size]))) +
acc);
acc += iree_math_f16_to_f32(lhs_data[k + m * k_size]) *
iree_math_f16_to_f32(rhs_data[n + k * n_size]);
}
result_data[n + m * n_size] = iree_math_f32_to_f16(acc);
}
Expand Down Expand Up @@ -460,11 +457,8 @@ static void reference_matmul_bf16_bf16_bf16_bf16(
iree_hal_dim_t m, iree_hal_dim_t n) {
float acc = acc_data ? iree_math_bf16_to_f32(acc_data[n + m * n_size]) : 0.f;
for (iree_hal_dim_t k = 0; k < k_size; ++k) {
acc = iree_math_round_to_nearest_bf16(
iree_math_round_to_nearest_bf16(
(iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
iree_math_bf16_to_f32(rhs_data[n + k * n_size]))) +
acc);
acc += iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
iree_math_bf16_to_f32(rhs_data[n + k * n_size]);
}
result_data[n + m * n_size] = iree_math_f32_to_bf16(acc);
}
Expand Down Expand Up @@ -672,9 +666,6 @@ static bool matmul_result_elements_agree(iree_e2e_test_value_t expected,
FLAG_acceptable_fp_delta;
case IREE_E2E_TEST_VALUE_TYPE_BF16:
if (actual.bf16_u16 == expected.bf16_u16) return true;
fprintf(stderr, "actual (%x) %g ; expected (%x) %g\n", actual.bf16_u16,
iree_math_bf16_to_f32(actual.bf16_u16), expected.bf16_u16,
iree_math_bf16_to_f32(expected.bf16_u16));
if (FLAG_require_exact_results) return false;
return fabsf(iree_math_bf16_to_f32(actual.bf16_u16) -
iree_math_bf16_to_f32(expected.bf16_u16)) <
Expand Down

0 comments on commit 16e4346

Please sign in to comment.