diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
index 6ec64ca94810f..dbff2c795f30b 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
@@ -232,6 +232,10 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
       return iree_uk_mmt4d_select_tile_func_arm_64_bf16bf16bf16(params);
     case iree_uk_mmt4d_type_s8s8s32:
       return iree_uk_mmt4d_select_tile_func_arm_64_i8i8i32(params);
+    case iree_uk_mmt4d_type_s16s16s32:
+      return 0;
+    case iree_uk_mmt4d_type_s16u4s32:
+      return 0;
     default:
       IREE_UK_ASSUME_UNREACHABLE;
       return 0;
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
index f31923e9611ef..09317c6de926f 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
@@ -292,6 +292,10 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
       return iree_uk_mmt4d_select_tile_func_x86_64_bf16bf16bf16(params);
     case iree_uk_mmt4d_type_s8s8s32:
       return iree_uk_mmt4d_select_tile_func_x86_64_i8i8i32(params);
+    case iree_uk_mmt4d_type_s16s16s32:
+      return 0;
+    case iree_uk_mmt4d_type_s16u4s32:
+      return 0;
     default:
       IREE_UK_ASSUME_UNREACHABLE;
       return 0;
diff --git a/runtime/src/iree/builtins/ukernel/common.h b/runtime/src/iree/builtins/ukernel/common.h
index b196d856776ab..8283cda572848 100644
--- a/runtime/src/iree/builtins/ukernel/common.h
+++ b/runtime/src/iree/builtins/ukernel/common.h
@@ -461,14 +461,17 @@ enum {
   IREE_UK_TYPE_OPAQUE_16 = IREE_UK_TYPE_CATEGORY_OPAQUE | 4,
   IREE_UK_TYPE_OPAQUE_32 = IREE_UK_TYPE_CATEGORY_OPAQUE | 5,
   IREE_UK_TYPE_OPAQUE_64 = IREE_UK_TYPE_CATEGORY_OPAQUE | 6,
+  IREE_UK_TYPE_INT_4 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 2,
   IREE_UK_TYPE_INT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 3,
   IREE_UK_TYPE_INT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 4,
   IREE_UK_TYPE_INT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 5,
   IREE_UK_TYPE_INT_64 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 6,
+  IREE_UK_TYPE_SINT_4 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 2,
   IREE_UK_TYPE_SINT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 3,
   IREE_UK_TYPE_SINT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 4,
   IREE_UK_TYPE_SINT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 5,
   IREE_UK_TYPE_SINT_64 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 6,
+  IREE_UK_TYPE_UINT_4 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 2,
   IREE_UK_TYPE_UINT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 3,
   IREE_UK_TYPE_UINT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 4,
   IREE_UK_TYPE_UINT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 5,
@@ -531,7 +534,9 @@ static inline iree_uk_uint8_t iree_uk_integer_type_as_unsigned(
 // The current implementation might return a negative value, but don't rely on
 // that.
 static inline int iree_uk_type_size_log2(iree_uk_type_t t) {
-  return iree_uk_type_bit_count_log2(t) - 3;
+  int bit_count_log2 = iree_uk_type_bit_count_log2(t);
+  IREE_UK_ASSERT(bit_count_log2 >= 3);
+  return bit_count_log2 - 3;
 }
 
 static inline int iree_uk_type_bit_count(iree_uk_type_t t) {
@@ -545,6 +550,21 @@ static inline int iree_uk_type_size(iree_uk_type_t t) {
   return 1 << iree_uk_type_size_log2(t);
 }
 
+// Helper to correctly convert a bit-size to a byte-size, rounding up if the
+// bit-size is not a multiple of 8.
+static inline iree_uk_index_t iree_uk_bits_to_bytes_rounding_up(
+    iree_uk_index_t bits) {
+  return (bits + 7) / 8;
+}
+
+// Helper to correctly convert a bit-size to a byte-size, asserting that the
+// bit-size is a multiple of 8.
+static inline iree_uk_index_t iree_uk_bits_to_bytes_exact(
+    iree_uk_index_t bits) {
+  IREE_UK_ASSERT(!(bits % 8));
+  return bits / 8;
+}
+
 //===----------------------------------------------------------------------===//
 // Tuples of types, packed ("tied") into a word.
 //===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/builtins/ukernel/exported_bits.h b/runtime/src/iree/builtins/ukernel/exported_bits.h
index e2dfd6af1d33b..c027c5b3b4e10 100644
--- a/runtime/src/iree/builtins/ukernel/exported_bits.h
+++ b/runtime/src/iree/builtins/ukernel/exported_bits.h
@@ -49,6 +49,9 @@
 #define IREE_UK_FLAG_MMT4D_TYPE_F16F16F16 0x04
 #define IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32 0x05
 #define IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16 0x06
+#define IREE_UK_FLAG_MMT4D_TYPE_S16S16S32 0x07
+#define IREE_UK_FLAG_MMT4D_TYPE_S16U4S32 0x08
+#define IREE_UK_FLAG_MMT4D_TYPE_END 0x09
 
 // bit flags
 #define IREE_UK_FLAG_MMT4D_ACCUMULATE 0x100
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.c b/runtime/src/iree/builtins/ukernel/mmt4d.c
index a626a8c7da08e..e775ca1265145 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d.c
@@ -15,12 +15,7 @@ static void iree_uk_mmt4d_validate(const iree_uk_mmt4d_params_t* params) {
       IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS;
   IREE_UK_ASSERT(!(params->flags & ~allflags));
   iree_uk_uint32_t flags_type = params->flags & IREE_UK_FLAG_MMT4D_TYPE_MASK;
-  IREE_UK_ASSERT(flags_type == IREE_UK_FLAG_MMT4D_TYPE_F32F32F32 ||
-                 flags_type == IREE_UK_FLAG_MMT4D_TYPE_S8S8S32 ||
-                 flags_type == IREE_UK_FLAG_MMT4D_TYPE_F16F16F32 ||
-                 flags_type == IREE_UK_FLAG_MMT4D_TYPE_F16F16F16 ||
-                 flags_type == IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32 ||
-                 flags_type == IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16);
+  IREE_UK_ASSERT(flags_type < IREE_UK_FLAG_MMT4D_TYPE_END);
   // Some implementations may wish to avoid supporting absurdly wide types. For
   // instance, K is the innermost (i.e. hottest) loop bound, so some 32bit
   // targets may benefit from K being int32, not int64. We still let K be of
@@ -33,8 +28,22 @@ static void iree_uk_mmt4d_validate(const iree_uk_mmt4d_params_t* params) {
   IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->M0, 15));
   IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->N0, 15));
   IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->K0, 15));
-  // Ensure iree_uk_mmt4d_tile_generic_max_bytes large enough for this tile.
+
+  // Requirements on sub-byte element type cases
+  // - Ensure that the output type is not sub-byte.
   iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags);
+  IREE_UK_ASSERT(iree_uk_type_bit_count(iree_uk_mmt4d_out_type(mmt4d_type)) >=
+                 8);
+  // - Ensure that (K0 * {LHS,RHS} element bits) is a multiple of 8 bits.
+  int lhs_bits = iree_uk_type_bit_count(iree_uk_mmt4d_lhs_type(mmt4d_type));
+  int rhs_bits = iree_uk_type_bit_count(iree_uk_mmt4d_lhs_type(mmt4d_type));
+  IREE_UK_ASSERT(!((params->K0 * lhs_bits) % 8));
+  IREE_UK_ASSERT(!((params->K0 * rhs_bits) % 8));
+  // - Ensure that {LHS,RHS} strides are multiples of 8 bits.
+  IREE_UK_ASSERT(!((params->lhs_stride0 * lhs_bits) % 8));
+  IREE_UK_ASSERT(!((params->rhs_stride0 * rhs_bits) % 8));
+
+  // Ensure iree_uk_mmt4d_tile_generic_max_bytes large enough for this tile.
   IREE_UK_ASSERT(params->M0 * params->N0 *
                      iree_uk_type_size(iree_uk_mmt4d_out_type(mmt4d_type)) <=
                  iree_uk_mmt4d_tile_generic_max_bytes);
@@ -56,18 +65,24 @@ static void iree_uk_mmt4d_using_tile_func(const iree_uk_mmt4d_params_t* params,
   const iree_uk_type_t lhs_type = iree_uk_mmt4d_lhs_type(mmt4d_type);
   const iree_uk_type_t rhs_type = iree_uk_mmt4d_rhs_type(mmt4d_type);
   const iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
-  const iree_uk_int16_t lhs_elem_size_log2 = iree_uk_type_size_log2(lhs_type);
-  const iree_uk_int16_t rhs_elem_size_log2 = iree_uk_type_size_log2(rhs_type);
+  const iree_uk_int16_t lhs_elem_bits_log2 =
+      iree_uk_type_bit_count_log2(lhs_type);
+  const iree_uk_int16_t rhs_elem_bits_log2 =
+      iree_uk_type_bit_count_log2(rhs_type);
   const iree_uk_int16_t out_elem_size_log2 = iree_uk_type_size_log2(out_type);
   char* out_tile_row =
       (char*)params->out_buffer + (params->out_offset << out_elem_size_log2);
-  const char* lhs_panel = (const char*)params->lhs_buffer +
-                          (params->lhs_offset << lhs_elem_size_log2);
-  const char* rhs_panel_start = (const char*)params->rhs_buffer +
-                                (params->rhs_offset << rhs_elem_size_log2);
+  const char* lhs_panel =
+      (const char*)params->lhs_buffer +
+      iree_uk_bits_to_bytes_exact(params->lhs_offset << lhs_elem_bits_log2);
+  const char* rhs_panel_start =
+      (const char*)params->rhs_buffer +
+      iree_uk_bits_to_bytes_exact(params->rhs_offset << rhs_elem_bits_log2);
   iree_uk_int32_t out_tile_size = (M0 * N0) << out_elem_size_log2;
-  iree_uk_index_t lhs_panel_stride = params->lhs_stride0 << lhs_elem_size_log2;
-  iree_uk_index_t rhs_panel_stride = params->rhs_stride0 << rhs_elem_size_log2;
+  iree_uk_index_t lhs_panel_stride =
+      iree_uk_bits_to_bytes_exact(params->lhs_stride0 << lhs_elem_bits_log2);
+  iree_uk_index_t rhs_panel_stride =
+      iree_uk_bits_to_bytes_exact(params->rhs_stride0 << rhs_elem_bits_log2);
   iree_uk_index_t out_stride = params->out_stride0 << out_elem_size_log2;
   for (iree_uk_int32_t i = 0; i < M; ++i) {
     char* out_tile = out_tile_row;
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
index ebdd673194fc0..1b743c2593a1d 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
@@ -14,6 +14,10 @@ typedef enum iree_uk_mmt4d_type_t {
       IREE_UK_TIE_3_TYPES_LITERAL(FLOAT_32, FLOAT_32, FLOAT_32),
   iree_uk_mmt4d_type_s8s8s32 =
       IREE_UK_TIE_3_TYPES_LITERAL(SINT_8, SINT_8, SINT_32),
+  iree_uk_mmt4d_type_s16s16s32 =
+      IREE_UK_TIE_3_TYPES_LITERAL(SINT_16, SINT_16, SINT_32),
+  iree_uk_mmt4d_type_s16u4s32 =
+      IREE_UK_TIE_3_TYPES_LITERAL(SINT_16, UINT_4, SINT_32),
   iree_uk_mmt4d_type_f16f16f32 =
       IREE_UK_TIE_3_TYPES_LITERAL(FLOAT_16, FLOAT_16, FLOAT_32),
   iree_uk_mmt4d_type_f16f16f16 =
@@ -30,6 +34,10 @@ static inline iree_uk_mmt4d_type_t iree_uk_mmt4d_type(iree_uk_uint32_t flags) {
       return iree_uk_mmt4d_type_f32f32f32;
     case IREE_UK_FLAG_MMT4D_TYPE_S8S8S32:
       return iree_uk_mmt4d_type_s8s8s32;
+    case IREE_UK_FLAG_MMT4D_TYPE_S16S16S32:
+      return iree_uk_mmt4d_type_s16s16s32;
+    case IREE_UK_FLAG_MMT4D_TYPE_S16U4S32:
+      return iree_uk_mmt4d_type_s16u4s32;
     case IREE_UK_FLAG_MMT4D_TYPE_F16F16F32:
       return iree_uk_mmt4d_type_f16f16f32;
     case IREE_UK_FLAG_MMT4D_TYPE_F16F16F16:
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c
index 622ae89863bd4..6604d3baedb11 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c
@@ -6,7 +6,7 @@
 
 #include "iree/builtins/ukernel/mmt4d_internal.h"
 
-// Generic implementation of matmul tile, i8*i8->i32 case.
+// Generic implementation of matmul tile, s8*s8->s32 case.
 static void iree_uk_mmt4d_tile_s8s8s32_generic(
     void* out_tile_untyped, const void* lhs_panel_untyped,
     const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
@@ -41,6 +41,84 @@ static void iree_uk_mmt4d_tile_s8s8s32_generic(
   for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
 }
 
+// Generic implementation of matmul tile, s16*s16->s32 case.
+static void iree_uk_mmt4d_tile_s16s16s32_generic(
+    void* out_tile_untyped, const void* lhs_panel_untyped,
+    const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
+  iree_uk_int32_t* out_tile = out_tile_untyped;
+  const iree_uk_int16_t* lhs_panel = lhs_panel_untyped;
+  const iree_uk_int16_t* rhs_panel = rhs_panel_untyped;
+  iree_uk_int16_t M0 = params->M0;
+  iree_uk_int16_t N0 = params->N0;
+  iree_uk_int16_t K0 = params->K0;
+  // Initialize the local accumulator tile.
+  iree_uk_int32_t acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
+  if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+    for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i];
+  } else {
+    for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
+  }
+  // Accumulation loop.
+  for (iree_uk_index_t k = 0; k < params->K; ++k) {
+    for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
+      for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
+        for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
+          iree_uk_int32_t lhs_i32 = lhs_panel[i0 * K0 + k0];
+          iree_uk_int32_t rhs_i32 = rhs_panel[j0 * K0 + k0];
+          acc[i0 * N0 + j0] += lhs_i32 * rhs_i32;
+        }
+      }
+    }
+    lhs_panel += M0 * K0;
+    rhs_panel += N0 * K0;
+  }
+  // Store the local accumulator tile to the destination.
+  for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
+}
+
+// Generic implementation of matmul tile, s16*u4->s32 case.
+static void iree_uk_mmt4d_tile_s16u4s32_generic(
+    void* out_tile_untyped, const void* lhs_panel_untyped,
+    const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
+  iree_uk_int32_t* out_tile = out_tile_untyped;
+  const iree_uk_int16_t* lhs_panel = lhs_panel_untyped;
+  const iree_uk_uint8_t* rhs_panel = rhs_panel_untyped;
+  iree_uk_int16_t M0 = params->M0;
+  iree_uk_int16_t N0 = params->N0;
+  iree_uk_int16_t K0 = params->K0;
+  // K0 must be even.
+  IREE_UK_ASSERT(!(K0 % 2));
+  iree_uk_int16_t K0half = K0 / 2;
+  // Initialize the local accumulator tile.
+  iree_uk_int32_t acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
+  if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+    for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i];
+  } else {
+    for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
+  }
+  // Accumulation loop.
+  for (iree_uk_index_t k = 0; k < params->K; ++k) {
+    for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
+      for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
+        // As K0 must be even, we 2x-unroll the K0 loop, writing a 2D dot
+        // product.
+        for (iree_uk_index_t k0h = 0; k0h < K0half; ++k0h) {
+          iree_uk_int32_t lhs_0 = lhs_panel[i0 * K0 + 2 * k0h];
+          iree_uk_int32_t lhs_1 = lhs_panel[i0 * K0 + 2 * k0h + 1];
+          iree_uk_uint8_t rhs_byte = rhs_panel[j0 * K0half + k0h];
+          iree_uk_int32_t rhs_0 = rhs_byte & 0xf;
+          iree_uk_int32_t rhs_1 = rhs_byte >> 4;
+          acc[i0 * N0 + j0] += lhs_0 * rhs_0 + lhs_1 * rhs_1;
+        }
+      }
+    }
+    lhs_panel += M0 * K0;
+    rhs_panel += N0 * K0half;
+  }
+  // Store the local accumulator tile to the destination.
+  for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
+}
+
 // Generic implementation of matmul tile, f32*f32->f32 case.
 static void iree_uk_mmt4d_tile_f32f32f32_generic(
     void* out_tile_untyped, const void* lhs_panel_untyped,
@@ -227,6 +305,10 @@ static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic(
       return iree_uk_mmt4d_tile_f32f32f32_generic;
     case iree_uk_mmt4d_type_s8s8s32:
       return iree_uk_mmt4d_tile_s8s8s32_generic;
+    case iree_uk_mmt4d_type_s16s16s32:
+      return iree_uk_mmt4d_tile_s16s16s32_generic;
+    case iree_uk_mmt4d_type_s16u4s32:
+      return iree_uk_mmt4d_tile_s16u4s32_generic;
     case iree_uk_mmt4d_type_f16f16f32:
       return iree_uk_mmt4d_tile_f16f16f32_generic;
     case iree_uk_mmt4d_type_f16f16f16:
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 84ae68f9e07bd..21d1ea7464f99 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -137,12 +137,47 @@ static void iree_mmt4d_reference_innerloop_s8s8s32(
   *out_ptr = acc;
 }
 
+static void iree_mmt4d_reference_innerloop_s16s16s32(
+    int32_t* out_ptr, const int16_t* lhs_ptr, const int16_t* rhs_ptr,
+    const iree_uk_mmt4d_params_t* params) {
+  int32_t acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0;
+  for (iree_uk_index_t k = 0; k < params->K; ++k) {
+    for (iree_uk_index_t k0 = 0; k0 < params->K0; ++k0) {
+      int32_t lhs_i32 = lhs_ptr[k * params->M0 * params->K0 + k0];
+      int32_t rhs_i32 = rhs_ptr[k * params->N0 * params->K0 + k0];
+      acc += lhs_i32 * rhs_i32;
+    }
+  }
+  *out_ptr = acc;
+}
+
+static void iree_mmt4d_reference_innerloop_s16u4s32(
+    int32_t* out_ptr, const int16_t* lhs_ptr, const uint8_t* rhs_ptr,
+    const iree_uk_mmt4d_params_t* params) {
+  // K0 must be even.
+  IREE_UK_ASSERT(!(params->K0 % 2));
+  iree_uk_int16_t K0half = params->K0 / 2;
+  int32_t acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0;
+  for (iree_uk_index_t k = 0; k < params->K; ++k) {
+    // As K0 must be even, we 2x-unroll the K0 loop, writing a 2D dot product.
+    for (iree_uk_index_t k0h = 0; k0h < K0half; ++k0h) {
+      int32_t lhs_0 = lhs_ptr[k * params->M0 * params->K0 + 2 * k0h];
+      int32_t lhs_1 = lhs_ptr[k * params->M0 * params->K0 + 2 * k0h + 1];
+      uint8_t rhs_byte = rhs_ptr[k * params->N0 * K0half + k0h];
+      int32_t rhs_0 = rhs_byte & 0xf;
+      int32_t rhs_1 = rhs_byte >> 4;
+      acc += lhs_0 * rhs_0 + lhs_1 * rhs_1;
+    }
+  }
+  *out_ptr = acc;
+}
+
 static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) {
   iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags);
-  iree_uk_index_t lhs_elem_size =
-      iree_uk_type_size(iree_uk_mmt4d_lhs_type(mmt4d_type));
-  iree_uk_index_t rhs_elem_size =
-      iree_uk_type_size(iree_uk_mmt4d_rhs_type(mmt4d_type));
+  iree_uk_index_t lhs_elem_bits =
+      iree_uk_type_bit_count(iree_uk_mmt4d_lhs_type(mmt4d_type));
+  iree_uk_index_t rhs_elem_bits =
+      iree_uk_type_bit_count(iree_uk_mmt4d_rhs_type(mmt4d_type));
   iree_uk_index_t out_elem_size =
       iree_uk_type_size(iree_uk_mmt4d_out_type(mmt4d_type));
   for (iree_uk_index_t i = 0; i < params->M; ++i) {
@@ -153,18 +188,22 @@ static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) {
                                out_elem_size;
       const void* lhs_panel_ptr =
           ((const char*)params->lhs_buffer) +
-          (params->lhs_offset + i * params->lhs_stride0) * lhs_elem_size;
+          iree_uk_bits_to_bytes_exact(
+              (params->lhs_offset + i * params->lhs_stride0) * lhs_elem_bits);
       const void* rhs_panel_ptr =
           ((const char*)params->rhs_buffer) +
-          (params->rhs_offset + j * params->rhs_stride0) * rhs_elem_size;
+          iree_uk_bits_to_bytes_exact(
+              (params->rhs_offset + j * params->rhs_stride0) * rhs_elem_bits);
       for (iree_uk_index_t i0 = 0; i0 < params->M0; ++i0) {
         for (iree_uk_index_t j0 = 0; j0 < params->N0; ++j0) {
           void* out_ptr =
               ((char*)out_tile_ptr) + (i0 * params->N0 + j0) * out_elem_size;
           const void* lhs_ptr =
-              ((char*)lhs_panel_ptr) + i0 * params->K0 * lhs_elem_size;
+              ((char*)lhs_panel_ptr) +
+              iree_uk_bits_to_bytes_exact(i0 * params->K0 * lhs_elem_bits);
           const void* rhs_ptr =
-              ((char*)rhs_panel_ptr) + j0 * params->K0 * rhs_elem_size;
+              ((char*)rhs_panel_ptr) +
+              iree_uk_bits_to_bytes_exact(j0 * params->K0 * rhs_elem_bits);
           switch (params->flags & IREE_UK_FLAG_MMT4D_TYPE_MASK) {
             case IREE_UK_FLAG_MMT4D_TYPE_F32F32F32:
               iree_mmt4d_reference_innerloop_f32f32f32(
@@ -196,6 +235,16 @@ static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) {
                   (int32_t*)out_ptr, (const int8_t*)lhs_ptr,
                   (const int8_t*)rhs_ptr, params);
               break;
+            case IREE_UK_FLAG_MMT4D_TYPE_S16S16S32:
+              iree_mmt4d_reference_innerloop_s16s16s32(
+                  (int32_t*)out_ptr, (const int16_t*)lhs_ptr,
+                  (const int16_t*)rhs_ptr, params);
+              break;
+            case IREE_UK_FLAG_MMT4D_TYPE_S16U4S32:
+              iree_mmt4d_reference_innerloop_s16u4s32(
+                  (int32_t*)out_ptr, (const int16_t*)lhs_ptr,
+                  (const uint8_t*)rhs_ptr, params);
+              break;
             default:
               IREE_UK_ASSERT(false && "unhandled type");
           }
@@ -206,23 +255,47 @@ static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) {
   }
 }
 
+static iree_uk_index_t iree_uk_test_round_up_to_ensure_multiple_of_8_bits(
+    iree_uk_index_t index, iree_uk_type_t type) {
+  // Honor the requirement that strides should be multiples of 8 bits.
+  while ((index << iree_uk_type_bit_count_log2(type)) & 7) {
+    ++index;
+  }
+  return index;
+}
+
+static iree_uk_index_t iree_uk_test_random_stride(
+    iree_uk_index_t min_stride, iree_uk_type_t type,
+    iree_uk_random_engine_t* engine) {
+  // Randomly make strides either tight or not to exercise all cases.
+  iree_uk_index_t stride = min_stride + iree_uk_random_engine_get_0_1(engine);
+  return iree_uk_test_round_up_to_ensure_multiple_of_8_bits(stride, type);
+}
+
+static iree_uk_index_t iree_uk_test_random_offset(
+    iree_uk_type_t type, iree_uk_random_engine_t* engine) {
+  // Randomly make strides either tight or not to exercise all cases.
+  iree_uk_index_t stride = iree_uk_random_engine_get_0_1(engine);
+  return iree_uk_test_round_up_to_ensure_multiple_of_8_bits(stride, type);
+}
+
 static void iree_uk_test_mmt4d_for_shape_params(
     iree_uk_test_t* test, const iree_uk_mmt4d_params_t* src_params) {
   iree_uk_mmt4d_params_t params;
   memcpy(&params, src_params, sizeof params);
-  // Populate strides first - we need them below to compute buffer lengths.
-  // Randomly make strides either tight or not to exercise all cases.
-  iree_uk_random_engine_t* engine = iree_uk_test_random_engine(test);
-  params.lhs_stride0 =
-      params.K * params.M0 * params.K0 + iree_uk_random_engine_get_0_1(engine);
-  params.rhs_stride0 =
-      params.K * params.N0 * params.K0 + iree_uk_random_engine_get_0_1(engine);
-  params.out_stride0 =
-      params.N * params.M0 * params.N0 + iree_uk_random_engine_get_0_1(engine);
   iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params.flags);
   iree_uk_type_t lhs_type = iree_uk_mmt4d_lhs_type(mmt4d_type);
   iree_uk_type_t rhs_type = iree_uk_mmt4d_rhs_type(mmt4d_type);
   iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
+  // Populate strides first - we need them below to compute buffer lengths.
+  // Randomly make strides either tight or not to exercise all cases.
+  iree_uk_random_engine_t* engine = iree_uk_test_random_engine(test);
+  params.lhs_stride0 = iree_uk_test_random_stride(
+      params.K * params.M0 * params.K0, lhs_type, engine);
+  params.rhs_stride0 = iree_uk_test_random_stride(
+      params.K * params.N0 * params.K0, rhs_type, engine);
+  params.out_stride0 = iree_uk_test_random_stride(
+      params.N * params.M0 * params.N0, out_type, engine);
   iree_uk_index_t lhs_buffer_size =
       iree_uk_2d_buffer_length(lhs_type, params.M, params.lhs_stride0);
   iree_uk_index_t rhs_buffer_size =
@@ -231,13 +304,17 @@ static void iree_uk_test_mmt4d_for_shape_params(
   void* rhs_buffer = malloc(rhs_buffer_size);
   iree_uk_write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine);
   iree_uk_write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine);
-  params.lhs_offset = iree_uk_random_engine_get_0_65535(engine);
-  params.rhs_offset = iree_uk_random_engine_get_0_65535(engine);
-  params.out_offset = iree_uk_random_engine_get_0_65535(engine);
-  params.lhs_buffer = (const char*)lhs_buffer -
-                      (params.lhs_offset * iree_uk_type_size(lhs_type));
-  params.rhs_buffer = (const char*)rhs_buffer -
-                      (params.rhs_offset * iree_uk_type_size(rhs_type));
+  params.lhs_offset = iree_uk_test_random_offset(lhs_type, engine);
+  params.rhs_offset = iree_uk_test_random_offset(rhs_type, engine);
+  params.out_offset = iree_uk_test_random_offset(out_type, engine);
+  params.lhs_buffer =
+      (const char*)lhs_buffer -
+      iree_uk_bits_to_bytes_exact(params.lhs_offset
+                                  << iree_uk_type_bit_count_log2(lhs_type));
+  params.rhs_buffer =
+      (const char*)rhs_buffer -
+      iree_uk_bits_to_bytes_exact(params.rhs_offset
+                                  << iree_uk_type_bit_count_log2(rhs_type));
 
   iree_uk_mmt4d_params_t reference_params;
   memcpy(&reference_params, &params, sizeof params);
@@ -250,14 +327,17 @@ static void iree_uk_test_mmt4d_for_shape_params(
   memcpy(reference_out_buffer, init_out_buffer, out_buffer_size);
   reference_params.out_buffer =
       (char*)reference_out_buffer -
-      (params.out_offset * iree_uk_type_size(out_type));
+      iree_uk_bits_to_bytes_exact(params.out_offset
+                                  << iree_uk_type_bit_count_log2(out_type));
 
   iree_uk_mmt4d_params_t actual_params;
   memcpy(&actual_params, &params, sizeof params);
   void* actual_out_buffer = malloc(out_buffer_size);
   memcpy(actual_out_buffer, init_out_buffer, out_buffer_size);
-  actual_params.out_buffer = (char*)actual_out_buffer -
-                             (params.out_offset * iree_uk_type_size(out_type));
+  actual_params.out_buffer =
+      (char*)actual_out_buffer -
+      iree_uk_bits_to_bytes_exact(params.out_offset
+                                  << iree_uk_type_bit_count_log2(out_type));
 
   iree_mmt4d_reference(&reference_params);
   iree_uk_mmt4d(&actual_params);
@@ -376,6 +456,8 @@ int main(int argc, char** argv) {
   // in a power-of-two assumption
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 3, 5, 7, "");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 9, 6, 3, "");
+  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 7, 3, 6, "");
+  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 5, 3, 2, "");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 4, 6, 5, "");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 3, 5, 8, "");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 11, 4, 1, "");
diff --git a/runtime/src/iree/builtins/ukernel/tools/util.c b/runtime/src/iree/builtins/ukernel/tools/util.c
index 2f94964d7bb86..8c8b545abf912 100644
--- a/runtime/src/iree/builtins/ukernel/tools/util.c
+++ b/runtime/src/iree/builtins/ukernel/tools/util.c
@@ -33,20 +33,27 @@ void iree_uk_assert_fail(const char* file, int line, const char* function,
 iree_uk_index_t iree_uk_2d_buffer_length(iree_uk_type_t type,
                                          iree_uk_index_t size0,
                                          iree_uk_index_t stride0) {
-  // Just for testing purposes, so it's OK to overestimate size.
-  return size0 * stride0 << iree_uk_type_size_log2(type);
+  // As we require strides to be multiples of 8 bits, the stride value in bytes
+  // is exact.
+  return size0 * iree_uk_bits_to_bytes_exact(
+                     stride0 << iree_uk_type_bit_count_log2(type));
 }
 
 bool iree_uk_2d_buffers_equal(const void* buf1, const void* buf2,
                               iree_uk_type_t type, iree_uk_index_t size0,
                               iree_uk_index_t size1, iree_uk_index_t stride0) {
-  iree_uk_index_t elem_size = iree_uk_type_size(type);
+  // Sizes don't have to be multiples of 8 bits.
+  iree_uk_index_t size1_bytes = iree_uk_bits_to_bytes_rounding_up(
+      size1 << iree_uk_type_bit_count_log2(type));
+  // Strides are required to be multiples of 8 bits.
+  iree_uk_index_t stride0_bytes =
+      iree_uk_bits_to_bytes_exact(stride0 << iree_uk_type_bit_count_log2(type));
   const char* buf1_ptr = buf1;
   const char* buf2_ptr = buf2;
   for (iree_uk_index_t i0 = 0; i0 < size0; ++i0) {
-    if (memcmp(buf1_ptr, buf2_ptr, elem_size * size1)) return false;
-    buf1_ptr += elem_size * stride0;
-    buf2_ptr += elem_size * stride0;
+    if (memcmp(buf1_ptr, buf2_ptr, size1_bytes)) return false;
+    buf1_ptr += stride0_bytes;
+    buf2_ptr += stride0_bytes;
   }
   return true;
 }
@@ -73,6 +80,11 @@ int iree_uk_random_engine_get_0_65535(iree_uk_random_engine_t* e) {
   return (v >> 8) & 0xffff;
 }
 
+int iree_uk_random_engine_get_0_255(iree_uk_random_engine_t* e) {
+  int v = iree_uk_random_engine_get_0_65535(e);
+  return v & 0xff;
+}
+
 int iree_uk_random_engine_get_0_1(iree_uk_random_engine_t* e) {
   int v = iree_uk_random_engine_get_0_65535(e);
   return v & 1;
@@ -97,6 +109,16 @@ void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes,
     iree_uk_write_random_buffer(buffer, size_in_bytes, resolved_type, engine);
     return;
   }
+  // Special-case sub-byte-size integer types. Due to their narrow range, we
+  // want to generate values over their entire range, and then it's down to
+  // just generating random bytes.
+  if (iree_uk_type_is_integer(type) && iree_uk_type_bit_count(type) < 8) {
+    for (iree_uk_index_t i = 0; i < size_in_bytes; ++i) {
+      ((uint8_t*)buffer)[i] = iree_uk_random_engine_get_0_255(engine);
+    }
+    return;
+  }
+  // All other element types.
   iree_uk_index_t elem_size = iree_uk_type_size(type);
   iree_uk_index_t size_in_elems = size_in_bytes / elem_size;
   for (iree_uk_index_t i = 0; i < size_in_elems; ++i) {