Skip to content

Commit

Permalink
[Mosaic TPU] (8,128),-2 -> (8,128) for non-zero and replicated 2nd mi…
Browse files Browse the repository at this point in the history
…nor offset

Also fix a bug where a relayout from a fully replicated source was assumed to be a no-op without checking implicit dims.

PiperOrigin-RevId: 655746766
  • Loading branch information
tlongeri authored and jax authors committed Jul 24, 2024
1 parent f1cfd99 commit 220ec2a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 21 deletions.
4 changes: 2 additions & 2 deletions jaxlib/mosaic/dialect/tpu/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class VectorLayout {
}

template <typename T>
void insertImplicit(SmallVector<T> &vec, T value) const {
void insertImplicit(SmallVectorImpl<T> &vec, T value) const {
CHECK_GE(vec.size(), layout_rank());
switch (implicit_dim_) {
case ImplicitDim::kNone:
Expand All @@ -302,7 +302,7 @@ class VectorLayout {
}

template <typename T>
void eraseImplicit(SmallVector<T> &vec) const {
void eraseImplicit(SmallVectorImpl<T> &vec) const {
CHECK_GE(vec.size(), 2);
switch (implicit_dim_) {
case ImplicitDim::kNone:
Expand Down
71 changes: 52 additions & 19 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5021,8 +5021,37 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
if (bitwidth != dst.bitwidth()) {
return emitError(v.getLoc(), "Can't change bitwidth during a relayout");
}
const int packing = src.packing();
VectorType vty = v.getType();
{
// Replication imposes a replication constraint on the *logical* value of
// the vector: When moving along a replicated axis, all elements must be
// equal. Note that when the axis is a singleton, there is effectively no
// added *logical* constraint.
// For example, a vector<2x2xf32> v with no implicit dims and layout offsets
// {*, 0} is expected to satisfy v[0, 0] == v[1, 0] and v[0, 1] == v[1, 1].
// Relayout does not change the logical value of the vector. Any replication
// constraints in the result must be guaranteed by the source layout.
SmallVector<LayoutOffset, 2> src_offsets(ArrayRef(src.offsets()));
SmallVector<LayoutOffset, 2> dst_offsets(ArrayRef(dst.offsets()));
// Remove implicit dims to get offsets for trailing logical dims.
src.eraseImplicit(src_offsets);
dst.eraseImplicit(dst_offsets);
for (int i = dst_offsets.size(); i > 0; --i) {
const int64_t dim_size = *(vty.getShape().end() - i);
const bool dim_replicated_in_dst = !*(dst_offsets.end() - i);
// If the dim is untiled in the src layout, then there is no guarantee of
// replication, because we don't track replication for untiled dims.
const bool dim_replicated_in_src =
i <= src_offsets.size() && !*(src_offsets.end() - i);
if (dim_replicated_in_dst && !dim_replicated_in_src && dim_size != 1) {
return emitError(v.getLoc(),
"Invalid relayout: Non-singleton logical dimension is "
"replicated in destination but not in source for ")
<< vty << ": " << src << " -> " << dst;
}
}
}
const int packing = src.packing();

// Save the original value of dst to use it at the end. It determines the
// out_layout of the result of assemble.
Expand Down Expand Up @@ -5054,8 +5083,8 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
/*use_implicit_shape=*/true)
.getResult();
}
if (!src.offsets()[0].has_value() && !src.offsets()[1].has_value() &&
src.tilesPerVreg(target_shape) == 1) {
if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
!src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) {
// A fully replicated value is always easy to relayout.
// It would be nice to be able to assert this here, but given replicated
// values our rules can introduce equivalent expressions.
Expand Down Expand Up @@ -5258,25 +5287,29 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
// This drops the implicit second minor dimension.
src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
src.bitwidth() == 32 && src.offsets() == dst.offsets() &&
src.offsets() == LayoutOffsets{0, 0} && src.tiling() == dst.tiling() &&
src.bitwidth() == 32 && dst.offsets()[0] &&
src.offsets()[1] == dst.offsets()[1] && src.tiling() == dst.tiling() &&
src.tiling() == std::array<int64_t, 2>{8, 128}) {
xla::Array<Value> src_tiles_retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
src_tiles_retiled.Each(
[&](const absl::Span<const int64_t> idx, Value *tile) {
for (int dst_sl_idx = 0; dst_sl_idx < 8; ++dst_sl_idx) {
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
src.insertImplicit<int64_t>(src_idx, 0);
auto second_minor_idx = idx.size() - 2;
src_idx[second_minor_idx] = 8 * idx[second_minor_idx] + dst_sl_idx;
if (src_idx[second_minor_idx] >= src_tiles.dim(second_minor_idx)) {
break;
}
*tile = copy_one_sublane(builder, src_tiles(src_idx), 0, *tile,
dst_sl_idx, target_shape);
}
});
src_tiles_retiled.Each([&](const absl::Span<const int64_t> idx,
Value *tile) {
const int64_t dst_2nd_minor_idx = idx.size() - 2;
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
src.insertImplicit<int64_t>(src_idx, 0);
const int dst_sl_start =
idx[dst_2nd_minor_idx] == 0 ? *dst.offsets()[0] : 0;
src_idx[dst_2nd_minor_idx] = target_shape[0] * idx[dst_2nd_minor_idx] +
dst_sl_start - *dst.offsets()[0];
for (int dst_sl_idx = dst_sl_start;
dst_sl_idx < target_shape[0] &&
src_idx[dst_2nd_minor_idx] < src_tiles.dim(dst_2nd_minor_idx);
++dst_sl_idx, ++src_idx[dst_2nd_minor_idx]) {
*tile = copy_one_sublane(builder, src_tiles(src_idx),
src.offsets()[0].value_or(dst_sl_idx), *tile,
dst_sl_idx, target_shape);
}
});
src = dst;
src_tiles = std::move(src_tiles_retiled);
}
Expand Down

0 comments on commit 220ec2a

Please sign in to comment.