diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h
index 68dc01d84de9..8bd33d1cd434 100644
--- a/jaxlib/mosaic/dialect/tpu/layout.h
+++ b/jaxlib/mosaic/dialect/tpu/layout.h
@@ -288,7 +288,7 @@ class VectorLayout {
   }

   template <typename T>
-  void insertImplicit(SmallVector<T> &vec, T value) const {
+  void insertImplicit(SmallVectorImpl<T> &vec, T value) const {
     CHECK_GE(vec.size(), layout_rank());
     switch (implicit_dim_) {
       case ImplicitDim::kNone:
@@ -302,7 +302,7 @@ class VectorLayout {
   }

   template <typename T>
-  void eraseImplicit(SmallVector<T> &vec) const {
+  void eraseImplicit(SmallVectorImpl<T> &vec) const {
     CHECK_GE(vec.size(), 2);
     switch (implicit_dim_) {
       case ImplicitDim::kNone:
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 20f6ec8d0a3f..b3ab5de06035 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -5021,8 +5021,37 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
   if (bitwidth != dst.bitwidth()) {
     return emitError(v.getLoc(), "Can't change bitwidth during a relayout");
   }
-  const int packing = src.packing();
   VectorType vty = v.getType();
+  {
+    // Replication imposes a replication constraint on the *logical* value of
+    // the vector: When moving along a replicated axis, all elements must be
+    // equal. Note that when the axis is a singleton, there is effectively no
+    // added *logical* constraint.
+    // For example, a vector<2x2xf32> v with no implicit dims and layout offsets
+    // {*, 0} is expected to satisfy v[0, 0] == v[1, 0] and v[0, 1] == v[1, 1].
+    // Relayout does not change the logical value of the vector. Any replication
+    // constraints in the result must be guaranteed by the source layout.
+    SmallVector<LayoutOffset> src_offsets(ArrayRef(src.offsets()));
+    SmallVector<LayoutOffset> dst_offsets(ArrayRef(dst.offsets()));
+    // Remove implicit dims to get offsets for trailing logical dims.
+    src.eraseImplicit(src_offsets);
+    dst.eraseImplicit(dst_offsets);
+    for (int i = dst_offsets.size(); i > 0; --i) {
+      const int64_t dim_size = *(vty.getShape().end() - i);
+      const bool dim_replicated_in_dst = !*(dst_offsets.end() - i);
+      // If the dim is untiled in the src layout, then there is no guarantee of
+      // replication, because we don't track replication for untiled dims.
+      const bool dim_replicated_in_src =
+          i <= src_offsets.size() && !*(src_offsets.end() - i);
+      if (dim_replicated_in_dst && !dim_replicated_in_src && dim_size != 1) {
+        return emitError(v.getLoc(),
+                         "Invalid relayout: Non-singleton logical dimension is "
+                         "replicated in destination but not in source for ")
+               << vty << ": " << src << " -> " << dst;
+      }
+    }
+  }
+  const int packing = src.packing();

   // Save the original value of dst to use it at the end. It determines the
   // out_layout of the result of assemble.
@@ -5054,8 +5083,8 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
                  /*use_implicit_shape=*/true)
             .getResult();
   }
-  if (!src.offsets()[0].has_value() && !src.offsets()[1].has_value() &&
-      src.tilesPerVreg(target_shape) == 1) {
+  if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
+      !src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) {
     // A fully replicated value is always easy to relayout
     // It would be nice to be able to assert this here, but given replicated
     // values our rules can introduce equivalent expressions.
@@ -5258,25 +5287,29 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
       // This drops the implicit second minor dimension.
       src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
       dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
-      src.bitwidth() == 32 && src.offsets() == dst.offsets() &&
-      src.offsets() == LayoutOffsets{0, 0} && src.tiling() == dst.tiling() &&
+      src.bitwidth() == 32 && dst.offsets()[0] &&
+      src.offsets()[1] == dst.offsets()[1] && src.tiling() == dst.tiling() &&
       src.tiling() == std::array<int64_t, 2>{8, 128}) {
     xla::Array<Value> src_tiles_retiled(
         dst.tileArrayImplicitShape(vty.getShape(), target_shape));
-    src_tiles_retiled.Each(
-        [&](const absl::Span<const int64_t> idx, Value *tile) {
-          for (int dst_sl_idx = 0; dst_sl_idx < 8; ++dst_sl_idx) {
-            SmallVector<int64_t> src_idx(idx.begin(), idx.end());
-            src.insertImplicit<int64_t>(src_idx, 0);
-            auto second_minor_idx = idx.size() - 2;
-            src_idx[second_minor_idx] = 8 * idx[second_minor_idx] + dst_sl_idx;
-            if (src_idx[second_minor_idx] >= src_tiles.dim(second_minor_idx)) {
-              break;
-            }
-            *tile = copy_one_sublane(builder, src_tiles(src_idx), 0, *tile,
-                                     dst_sl_idx, target_shape);
-          }
-        });
+    src_tiles_retiled.Each([&](const absl::Span<const int64_t> idx,
+                               Value *tile) {
+      const int64_t dst_2nd_minor_idx = idx.size() - 2;
+      SmallVector<int64_t> src_idx(idx.begin(), idx.end());
+      src.insertImplicit<int64_t>(src_idx, 0);
+      const int dst_sl_start =
+          idx[dst_2nd_minor_idx] == 0 ? *dst.offsets()[0] : 0;
+      src_idx[dst_2nd_minor_idx] = target_shape[0] * idx[dst_2nd_minor_idx] +
+                                   dst_sl_start - *dst.offsets()[0];
+      for (int dst_sl_idx = dst_sl_start;
+           dst_sl_idx < target_shape[0] &&
+           src_idx[dst_2nd_minor_idx] < src_tiles.dim(dst_2nd_minor_idx);
+           ++dst_sl_idx, ++src_idx[dst_2nd_minor_idx]) {
+        *tile = copy_one_sublane(builder, src_tiles(src_idx),
+                                 src.offsets()[0].value_or(dst_sl_idx), *tile,
+                                 dst_sl_idx, target_shape);
+      }
+    });
     src = dst;
     src_tiles = std::move(src_tiles_retiled);
   }
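
Notes (reviewer commentary, not part of the patch):

The layout.h change swaps the `SmallVector<T> &` parameters for `SmallVectorImpl<T> &`. Since `llvm::SmallVector<T, N>` derives from `llvm::SmallVectorImpl<T>`, `insertImplicit` and `eraseImplicit` now accept vectors with any inline element count instead of only the default one. A minimal sketch of the idiom, using plain LLVM ADT and a made-up `pushZero` helper:

  #include <cstdint>
  #include "llvm/ADT/SmallVector.h"

  // With SmallVectorImpl<T>&, both calls below compile. With the previous
  // SmallVector<T>& (which fixes the inline element count to its default),
  // passing `b` would fail, because SmallVector<int64_t, 4> is a distinct
  // type that only shares the SmallVectorImpl<int64_t> base.
  template <typename T>
  void pushZero(llvm::SmallVectorImpl<T> &vec) { vec.push_back(T{}); }

  int main() {
    llvm::SmallVector<int64_t> a;     // default inline size
    llvm::SmallVector<int64_t, 4> b;  // custom inline size
    pushZero(a);
    pushZero(b);
  }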
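
The new block at the top of `relayout` rejects any relayout whose destination layout claims replication (a `*` offset) along a non-singleton logical dimension that the source layout never guaranteed. The standalone sketch below restates that loop over plain `std::optional` offsets, with `nullopt` standing for a replicated dimension; `RelayoutPreservesReplication` and its parameters are illustrative names, not Mosaic API:

  #include <cstdint>
  #include <iostream>
  #include <optional>
  #include <vector>

  using LayoutOffset = std::optional<int64_t>;  // nullopt = replicated ("*")

  // Returns true iff every dim that dst treats as replicated is also
  // replicated in src or is a singleton (so no logical constraint is added).
  bool RelayoutPreservesReplication(
      const std::vector<int64_t> &shape,
      const std::vector<LayoutOffset> &src_offsets,
      const std::vector<LayoutOffset> &dst_offsets) {
    // Walk the trailing logical dims covered by dst's offsets.
    for (int i = static_cast<int>(dst_offsets.size()); i > 0; --i) {
      const int64_t dim_size = *(shape.end() - i);
      const bool replicated_in_dst = !*(dst_offsets.end() - i);
      // Dims not covered by src's offsets carry no replication guarantee.
      const bool replicated_in_src =
          i <= static_cast<int>(src_offsets.size()) &&
          !*(src_offsets.end() - i);
      if (replicated_in_dst && !replicated_in_src && dim_size != 1) {
        return false;  // dst demands replication the source never guaranteed
      }
    }
    return true;
  }

  int main() {
    // {*, 0} -> {0, 0} on a 2x2 vector: fine, replication is only dropped.
    std::cout << RelayoutPreservesReplication({2, 2}, {std::nullopt, 0}, {0, 0})
              << '\n';  // 1
    // {0, 0} -> {*, 0}: invalid, dim 0 has size 2 but src never replicated it.
    std::cout << RelayoutPreservesReplication({2, 2}, {0, 0}, {std::nullopt, 0})
              << '\n';  // 0
    // Same relayout with a singleton dim 0: fine, no logical constraint added.
    std::cout << RelayoutPreservesReplication({1, 2}, {0, 0}, {std::nullopt, 0})
              << '\n';  // 1
  }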
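
The reworked kSecondMinor-to-kNone branch no longer requires offsets of {0, 0}: the first destination vreg row now starts copying at the destination's sublane offset. Assuming `target_shape[0]` is the sublane count (8 here), the index arithmetic maps logical source row r to destination vreg row (r + dst_off) / 8, sublane (r + dst_off) % 8. A standalone sketch of the same traversal with illustrative values:

  #include <iostream>

  int main() {
    constexpr int kSublanes = 8;  // stands in for target_shape[0]
    const int dst_off = 3;        // stands in for *dst.offsets()[0]
    const int num_src_rows = 10;  // stands in for src_tiles.dim(...)
    const int num_dst_rows =
        (dst_off + num_src_rows + kSublanes - 1) / kSublanes;

    for (int dst_row = 0; dst_row < num_dst_rows; ++dst_row) {
      // The first dst vreg row starts at sublane dst_off; later rows at 0.
      const int dst_sl_start = dst_row == 0 ? dst_off : 0;
      int src_row = kSublanes * dst_row + dst_sl_start - dst_off;
      for (int dst_sl = dst_sl_start;
           dst_sl < kSublanes && src_row < num_src_rows; ++dst_sl, ++src_row) {
        std::cout << "src row " << src_row << " -> dst vreg row " << dst_row
                  << ", sublane " << dst_sl << '\n';
      }
    }
  }

With dst_off = 3 this prints src rows 0..4 landing in sublanes 3..7 of vreg row 0, and src rows 5..9 landing in sublanes 0..4 of vreg row 1, mirroring the hunk's loop bounds and the src index update.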