[Pallas][Mosaic] Relax dynamic index on 2nd minor dim in load/store.
We now support any dynamic index on the 2nd minor dim in either of the following cases:
1. The minormost dim size of an unsliced memref matches the VREG lane count.
2. The load/store touches a single row on the second minormost dim, which triggers an implicit strided load/store.

Note: for the default cases that cannot skip the alignment check, we still use the dynamic slice + static load/store approach to reduce scalar core work. We should figure out a way to optimize this in all cases.
PiperOrigin-RevId: 648771794
bythew3i authored and jax authors committed Jul 2, 2024
1 parent cad751f commit 484d09f
Showing 2 changed files with 190 additions and 91 deletions.
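
To make the relaxed behavior concrete, here is a minimal Pallas sketch of the pattern this change targets: a load whose second-minor index is only known at run time. The sketch is not part of the commit; the kernel and function names (dynamic_row_kernel, dynamic_row) and the shapes are illustrative assumptions, and interpret=True is used so it runs without a TPU. On TPU, the same pattern lowers through the apply_vector_layout pass edited below, where an unaligned dynamic second-minor index used to be rejected and is now accepted when one of the two cases above holds.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


# Hypothetical example, not taken from the commit or its tests.
def dynamic_row_kernel(idx_ref, x_ref, o_ref):
    i = idx_ref[0]  # index only known at run time, e.g. 5 (not tile-aligned)
    # Dynamic index on the 2nd minor dim. The minormost dim (128) matches the
    # VREG lane count (case 1), and a single row is loaded (case 2).
    o_ref[...] = x_ref[pl.ds(i, 1), :]


def dynamic_row(idx, x):
    return pl.pallas_call(
        dynamic_row_kernel,
        out_shape=jax.ShapeDtypeStruct((1, x.shape[1]), x.dtype),
        interpret=True,  # drop this on a TPU to exercise the Mosaic lowering
    )(idx, x)


x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
print(dynamic_row(jnp.array([5], dtype=jnp.int32), x))  # row 5 of x, shape (1, 128)

For constant or provably aligned indices the pass still rewrites the access into a dynamic slice of the base ref followed by a static load/store, as noted above, so the relaxed path added here is only taken when the index cannot be proven divisible by the memref tiling.
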
192 changes: 139 additions & 53 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -2741,7 +2741,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
const VectorLayout &layout_out = *layouts_out.front();
ImplicitLocOpBuilder builder(op.getLoc(), &op);
auto load_op = cast<vector::LoadOp>(op);
const auto memref_ty = cast<MemRefType>(load_op.getBase().getType());
const auto memref_ty = getMemRefType(load_op.getBase());
const auto vty = cast<VectorType>(load_op.getResult().getType());
FAILUREOR_ASSIGN_OR_RETURN(
VectorType target_ty,
@@ -2772,36 +2772,80 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
}
// TODO(apaszke): Check that loads are from vmem!

int tiled_dims = is_1d ? 1 : 2;
Value base_addr;
SmallVector<int64_t> base_indices;
if (auto const_indices =
getIntConstsFromOperandRange(load_op.getIndices(), /*silent=*/true);
succeeded(const_indices)) {
base_addr = load_op.getBase();
base_indices = std::move(*const_indices);
} else {
auto slice_result =
sliceRef(builder, load_op.getBase(), load_op.getVectorType().getShape(),
load_op.getIndices(),
ArrayRef<int64_t>(memref_tiling).take_back(tiled_dims));
if (failed(slice_result)) {
return failure();
bool can_support_unaligned_dynamic_index = false;
bool must_support_unaligned_dynamic_index = false;
if (load_op.getIndices().size() > 1) {
auto second_minor_idx = load_op.getIndices().take_back(2)[0];
if (failed(getIntConst(second_minor_idx, /*silent=*/true)) &&
!isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) {
must_support_unaligned_dynamic_index = true;
}
std::tie(base_addr, base_indices) = *slice_result;
}
auto tile_base_idxs = ArrayRef<int64_t>(base_indices).take_back(tiled_dims);
auto batch_base_idxs = ArrayRef<int64_t>(base_indices).drop_back(tiled_dims);

const SmallVector<int64_t> implicit_shape =
layout_out.implicitShape(vty.getShape());
const int64_t ss = implicit_shape[implicit_shape.size() - 2];
int64_t sublane_stride = 1;
// Handle special patterns that allow us to support more flexible loads.
if (layout_out.bitwidth() == 32 &&
layout_out.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
ss == 1) {
// Loading a single row on the 2nd minor dim into the (1, 128) layout. We
// can use sublane striding to perform the relayout as part of the load.
sublane_stride = memref_tiling[0];
can_support_unaligned_dynamic_index = true;
} else {
// Otherwise, if the memref has a short last dimension and is contiguous
// all the tiled layouts become equivalent, so we can handle unaligned
// dynamic indices without any special case.
auto mem_layout = dyn_cast<TiledLayoutAttr>(memref_ty.getLayout());
if (!mem_layout) {
return op.emitOpError("Expected a tiled memref");
}
auto tile_strides = mem_layout.getTileStrides();
if (memref_ty.getShape().back() == ctx.target_shape[1] &&
tile_strides.take_back(2) == ArrayRef<int64_t>{1, 1}) {
can_support_unaligned_dynamic_index = true;
}
}

auto add_idx = [&](const Value &v, int64_t d) -> Value {
if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) {
return IdxConst(cst.value() + d, builder, op.getLoc());
}
return builder.create<arith::AddIOp>(v, IdxConst(d, builder, op.getLoc()));
};

int tiled_dims = is_1d ? 1 : 2;
Value base_addr = load_op.getBase();
SmallVector<Value, 4> base_indices = load_op.getIndices();

if (must_support_unaligned_dynamic_index) {
if (!can_support_unaligned_dynamic_index) {
return op.emitOpError(
"Not implemented: dynamic load with unaligned indices");
}
} else {
// Convert dynamic load to dynamic slice + static load. This saves us a
// bunch of scalar core work.
auto slice_result =
sliceRef(builder, load_op.getBase(), load_op.getVectorType().getShape(),
load_op.getIndices(),
ArrayRef<int64_t>(memref_tiling).take_back(tiled_dims));
if (failed(slice_result)) {
return failure();
}
base_addr = slice_result->first;
CHECK_EQ(slice_result->second.size(), base_indices.size());
for (int i = 0; i < base_indices.size(); ++i) {
base_indices[i] = IdxConst(slice_result->second[i], builder, op.getLoc());
}
}

// TODO(jevinjiang): ideally we should update the base addr and use static
// indices even for the cases that can skip alignment check. This can save us
// a bunch of scalar core work.
auto tile_base_idxs = ArrayRef<Value>(base_indices).take_back(tiled_dims);
auto batch_base_idxs = ArrayRef<Value>(base_indices).drop_back(tiled_dims);
const LayoutOffsets offsets = layout_out.offsets();
AffineMap load_map;
arith::ConstantOp padding;
@@ -2841,21 +2885,18 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
CHECK_EQ(num_dims, tile_idxs.size());
SmallVector<Value> idxs(tile_idxs.size());
for (int64_t i = 0; i < num_batch_dims; ++i) {
idxs[i] = IdxConst(batch_base_idxs[i] + tile_idxs[i], builder,
load_op->getLoc());
idxs[i] = add_idx(batch_base_idxs[i], tile_idxs[i]);
}
const int64_t base_l = tile_base_idxs.back();
const auto base_l = tile_base_idxs.back();
const int64_t lidx = tile_idxs[num_dims - 1];
idxs[num_dims - 1] =
IdxConst(base_l + lidx * vreg_slice[1] - *offsets[1], builder,
load_op->getLoc());
add_idx(base_l, lidx * vreg_slice[1] - offsets[1].value_or(0));
if (!is_1d) {
CHECK_EQ(tile_base_idxs.size(), 2);
const int64_t base_s = tile_base_idxs.front();
const auto base_s = tile_base_idxs.front();
const int64_t sidx = tile_idxs[num_dims - 2];
idxs[num_dims - 2] =
IdxConst(base_s + sidx * vreg_slice[0] - offsets[0].value_or(0),
builder, load_op->getLoc());
add_idx(base_s, sidx * vreg_slice[0] - offsets[0].value_or(0));
}
TPU_ASSERT_OP(tile_idxs[num_dims - 1] + ctx.target_shape[1] <=
memref_ty.getShape()[num_dims - 1]);
@@ -3919,6 +3960,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
vector::StoreOp store_op = cast<vector::StoreOp>(op);
const VectorType ty = store_op.getValueToStore().getType();
const VectorLayout &to_store_layout = *layouts_in.front();
const auto memref_ty = getMemRefType(store_op.getBase());
if (!ty.getRank()) {
return op.emitOpError("Not implemented: scalar stores to vmem");
}
@@ -3944,34 +3986,87 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
}
}

bool can_support_unaligned_dynamic_index = false;
bool must_support_unaligned_dynamic_index = false;
if (store_op.getIndices().size() > 1) {
auto second_minor_idx = store_op.getIndices().take_back(2)[0];
if (failed(getIntConst(second_minor_idx, /*silent=*/true)) &&
!isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) {
must_support_unaligned_dynamic_index = true;
}
}
int64_t sublane_stride = 1;
// Handle special patterns that allow us to support more flexible loads.
if (to_store_layout.bitwidth() == 32 &&
to_store_layout.tiling() == Tiling{1, ctx.target_shape[1]}) {
// Storing a single row on the 2nd minor dim from the (1, 128) layout. We
// can use sublane striding to perform the relayout as part of the store.
// The stride of store should be the number of sublanes in memref tile when
// store a single sublane.
sublane_stride = memref_tiling[0];
can_support_unaligned_dynamic_index = true;
} else {
// Otherwise, if the memref has a short last dimension and is contiguous
// all the tiled layouts become equivalent, so we can handle unaligned
// dynamic indices without any special case.
auto mem_layout = dyn_cast<TiledLayoutAttr>(memref_ty.getLayout());
if (!mem_layout) {
return op.emitOpError("Expected a tiled memref");
}
auto tile_strides = mem_layout.getTileStrides();
if (memref_ty.getShape().back() == ctx.target_shape[1] &&
tile_strides.take_back(2) == ArrayRef<int64_t>{1, 1}) {
can_support_unaligned_dynamic_index = true;
}
}

auto add_idx = [&](const Value &v, int64_t d) -> Value {
if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) {
return IdxConst(cst.value() + d, builder, op.getLoc());
}
return builder.create<arith::AddIOp>(v, IdxConst(d, builder, op.getLoc()));
};

int tiled_dims = is_1d ? 1 : 2;
Value base_addr;
SmallVector<int64_t> base_indices;
if (auto const_indices =
getIntConstsFromOperandRange(store_op.getIndices(), /*silent=*/true);
succeeded(const_indices)) {
base_addr = store_op.getBase();
base_indices = std::move(*const_indices);
Value base_addr = store_op.getBase();
SmallVector<Value, 4> base_indices = store_op.getIndices();

if (must_support_unaligned_dynamic_index) {
if (!can_support_unaligned_dynamic_index) {
return op.emitOpError(
"Not implemented: dynamic store with unaligned indices");
}
} else {
// Convert dynamic store to dynamic slice + static store. This saves us a
// bunch of scalar core work.
auto slice_result =
sliceRef(builder, store_op.getBase(),
store_op.getVectorType().getShape(), store_op.getIndices(),
ArrayRef<int64_t>(memref_tiling).take_back(tiled_dims));
if (failed(slice_result)) {
return failure();
}
std::tie(base_addr, base_indices) = *slice_result;
base_addr = slice_result->first;
CHECK_EQ(slice_result->second.size(), base_indices.size());
for (int i = 0; i < base_indices.size(); ++i) {
base_indices[i] = IdxConst(slice_result->second[i], builder, op.getLoc());
}
}
auto tile_base_idxs = ArrayRef<int64_t>(base_indices).take_back(tiled_dims);
auto batch_base_idxs = ArrayRef<int64_t>(base_indices).drop_back(tiled_dims);

// TODO(jevinjiang): ideally we should update the base addr and use static
// indices even for the cases that can skip alignment check. This can save
// us a bunch of scalar core work.
auto tile_base_idxs = ArrayRef<Value>(base_indices).take_back(tiled_dims);
auto batch_base_idxs = ArrayRef<Value>(base_indices).drop_back(tiled_dims);

FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> tiles,
disassemble(builder, to_store_layout, store_op.getValueToStore(),
ctx.target_shape));
const int64_t ndims = ty.getRank();
const int64_t base_s = is_1d ? 0 : tile_base_idxs.front();
const int64_t base_l = tile_base_idxs.back();
const auto base_s =
is_1d ? IdxConst(0, builder, op.getLoc()) : tile_base_idxs.front();
const auto base_l = tile_base_idxs.back();
if (is_1d) {
tiles.Reshape(
to_store_layout.implicitShape(toArrayRef(tiles.dimensions())));
Expand All @@ -3984,13 +4079,6 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
}
const SmallVector<int64_t> stored_shape =
to_store_layout.implicitShape(ty.getShape());
int64_t sublane_stride = 1;
// The stride of store should be the number of sublanes in memref tile when
// store a single sublane.
if (to_store_layout.bitwidth() == 32 &&
to_store_layout.tiling() == Tiling{1, ctx.target_shape[1]}) {
sublane_stride = memref_tiling[0];
}
const std::array<int64_t, 2> vreg_slice =
to_store_layout.vregSlice(ctx.target_shape);
const absl::Status status =
Expand All @@ -4002,17 +4090,15 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
const int64_t sidx = *(idx.end() - 2);
const int64_t lidx = *(idx.end() - 1);
SmallVector<Value> indices(ndims);
auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder,
store_op->getLoc());
for (int64_t i = 0; i < batch_base_idxs.size(); ++i) {
indices[i] = boundIdxConst(batch_base_idxs[i] + idx[i]);
indices[i] = add_idx(batch_base_idxs[i], idx[i]);
}
if (!is_1d) {
*(indices.end() - 2) =
boundIdxConst(base_s + sidx * vreg_slice[0] - *sublane_offset);
add_idx(base_s, sidx * vreg_slice[0] - *sublane_offset);
}
*(indices.end() - 1) =
boundIdxConst(base_l + lidx * vreg_slice[1] - *lane_offset);
add_idx(base_l, lidx * vreg_slice[1] - *lane_offset);
const DenseBoolArrayAttr sublane_mask =
bounds->getSublaneMask(store_op->getContext(), ctx.target_shape);
const bool masks_subelements =
@@ -4079,7 +4165,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
}
store_op->erase();
return success();
}
}

LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,