Skip to content

Commit

Permalink
[HWLegalizeModules] Lower types-like packed array handling (#5355)
Browse files Browse the repository at this point in the history
This PR refactors HWLegalizeModules to something more like existing lower types passes in order to support more complex patterns such as array concatenations. Several new tests are also added.

Fix #5355.
  • Loading branch information
yupferris committed Nov 10, 2023
1 parent d7088aa commit bf712b0
Show file tree
Hide file tree
Showing 2 changed files with 383 additions and 67 deletions.
329 changes: 266 additions & 63 deletions lib/Dialect/SV/Transforms/HWLegalizeModules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ struct HWLegalizeModulesPass

private:
void processPostOrder(Block &block);
Operation *tryLoweringArrayGet(hw::ArrayGetOp getOp);
bool tryLoweringPackedArrayOp(Operation &op);
Value lowerLookupToCasez(Operation &op, Value input, Value index,
mlir::Type elementType,
SmallVector<Value> caseValues);
bool processUsers(Operation &op, Value value, ArrayRef<Value> mapping);
std::optional<std::pair<uint64_t, unsigned>>
tryExtractIndexAndBitWidth(Value value);

/// This is the current hw.module being processed.
hw::HWModuleOp thisHWModule;
Expand All @@ -52,52 +58,218 @@ struct HWLegalizeModulesPass
};
} // end anonymous namespace

/// Try to lower a hw.array_get in module that doesn't support packed arrays.
/// This returns a replacement operation if lowering was successful, null
/// otherwise.
Operation *HWLegalizeModulesPass::tryLoweringArrayGet(hw::ArrayGetOp getOp) {
SmallVector<Value> caseValues;
OpBuilder builder(&thisHWModule.getBodyBlock()->front());
// If the operand is an array_create or aggregate constant, then we can lower
// this into a casez.
if (auto createOp = getOp.getInput().getDefiningOp<hw::ArrayCreateOp>())
caseValues = SmallVector<Value>(llvm::reverse(createOp.getOperands()));
else if (auto aggregateConstant =
getOp.getInput().getDefiningOp<hw::AggregateConstantOp>()) {
for (auto elem : llvm::reverse(aggregateConstant.getFields())) {
if (auto intAttr = dyn_cast<IntegerAttr>(elem))
caseValues.push_back(builder.create<hw::ConstantOp>(
aggregateConstant.getLoc(), intAttr));
else
caseValues.push_back(builder.create<hw::AggregateConstantOp>(
aggregateConstant.getLoc(), getOp.getType(),
elem.cast<ArrayAttr>()));
}
} else {
return nullptr;
}
bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) {
return TypeSwitch<Operation *, bool>(&op)
.Case<hw::AggregateConstantOp>([&](hw::AggregateConstantOp constOp) {
// Replace individual element uses (if any) with input fields.
SmallVector<Value> inputs;
OpBuilder builder(constOp);
for (auto field : llvm::reverse(constOp.getFields())) {
if (auto intAttr = dyn_cast<IntegerAttr>(field))
inputs.push_back(
builder.create<hw::ConstantOp>(constOp.getLoc(), intAttr));
else
inputs.push_back(builder.create<hw::AggregateConstantOp>(
constOp.getLoc(), constOp.getType(), field.cast<ArrayAttr>()));
}
if (!processUsers(op, constOp.getResult(), inputs))
return false;

// Remove original op.
return true;
})
.Case<hw::ArrayConcatOp>([&](hw::ArrayConcatOp concatOp) {
// Redirect individual element uses (if any) to the input arguments.
SmallVector<std::pair<Value, uint64_t>> arrays;
for (auto array : llvm::reverse(concatOp.getInputs())) {
auto ty = hw::type_cast<hw::ArrayType>(array.getType());
arrays.emplace_back(array, ty.getNumElements());
}
for (auto *user :
llvm::make_early_inc_range(concatOp.getResult().getUsers())) {
if (TypeSwitch<Operation *, bool>(user)
.Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
if (auto indexAndBitWidth =
tryExtractIndexAndBitWidth(getOp.getIndex())) {
auto [indexValue, bitWidth] = *indexAndBitWidth;
// FIXME: More efficient search
for (const auto &[array, size] : arrays) {
if (indexValue >= size) {
indexValue -= size;
continue;
}
OpBuilder builder(getOp);
getOp.getInputMutable().set(array);
getOp.getIndexMutable().set(
builder.createOrFold<hw::ConstantOp>(
getOp.getLoc(), APInt(bitWidth, indexValue)));
return true;
}
}

return false;
})
.Default([](auto op) { return false; }))
continue;

op.emitError("unsupported packed array expression");
signalPassFailure();
}

// Remove the original op.
return true;
})
.Case<hw::ArrayCreateOp>([&](hw::ArrayCreateOp createOp) {
// Replace individual element uses (if any) with input arguments.
SmallVector<Value> inputs(llvm::reverse(createOp.getInputs()));
if (!processUsers(op, createOp.getResult(), inputs))
return false;

// array_get(idx, array_create(a,b,c,d)) ==> casez(idx).
Value index = getOp.getIndex();
// Remove original op.
return true;
})
.Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
// Skip index ops with constant index.
auto index = getOp.getIndex();
if (auto *definingOp = index.getDefiningOp())
if (isa<hw::ConstantOp>(definingOp))
return false;

// Create the wire for the result of the casez in the hw.module.
auto theWire = builder.create<sv::RegOp>(getOp.getLoc(), getOp.getType(),
// Generate case value element lookups.
auto ty = hw::type_cast<hw::ArrayType>(getOp.getInput().getType());
OpBuilder builder(getOp);
SmallVector<Value> caseValues;
for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
auto loc = op.getLoc();
auto index = builder.createOrFold<hw::ConstantOp>(
loc, APInt(llvm::Log2_64_Ceil(e), i));
auto element =
builder.create<hw::ArrayGetOp>(loc, getOp.getInput(), index);
caseValues.push_back(element);
}

// Transform array index op into casez statement.
auto theWire = lowerLookupToCasez(op, getOp.getInput(), index,
ty.getElementType(), caseValues);

// Emit the read from the wire, replace uses and clean up.
builder.setInsertionPoint(getOp);
auto readWire =
builder.create<sv::ReadInOutOp>(getOp.getLoc(), theWire);
getOp.getResult().replaceAllUsesWith(readWire);
return true;
})
.Case<sv::ArrayIndexInOutOp>([&](sv::ArrayIndexInOutOp indexOp) {
// Skip index ops with constant index.
auto index = indexOp.getIndex();
if (auto *definingOp = index.getDefiningOp())
if (isa<hw::ConstantOp>(definingOp))
return false;

// Skip index ops with unpacked arrays.
auto inout = indexOp.getInput().getType();
if (hw::type_isa<hw::UnpackedArrayType>(inout.getElementType()))
return false;

// Generate case value element lookups.
auto ty = hw::type_cast<hw::ArrayType>(inout.getElementType());
OpBuilder builder(&op);
SmallVector<Value> caseValues;
for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
auto loc = op.getLoc();
auto index = builder.createOrFold<hw::ConstantOp>(
loc, APInt(llvm::Log2_64_Ceil(e), i));
auto element = builder.create<sv::ArrayIndexInOutOp>(
loc, indexOp.getInput(), index);
auto readElement = builder.create<sv::ReadInOutOp>(loc, element);
caseValues.push_back(readElement);
}

// Transform array index op into casez statement.
auto theWire = lowerLookupToCasez(op, indexOp.getInput(), index,
ty.getElementType(), caseValues);

// Replace uses and clean up.
indexOp.getResult().replaceAllUsesWith(theWire);
return true;
})
.Case<sv::PAssignOp>([&](sv::PAssignOp assignOp) {
// Transform array assignment into individual assignments for each array
// element.
auto inout = assignOp.getDest().getType();
auto ty = hw::type_dyn_cast<hw::ArrayType>(inout.getElementType());
if (!ty)
return false;

OpBuilder builder(assignOp);
for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
auto loc = op.getLoc();
auto index = builder.createOrFold<hw::ConstantOp>(
loc, APInt(llvm::Log2_64_Ceil(e), i));
auto dstElement = builder.create<sv::ArrayIndexInOutOp>(
loc, assignOp.getDest(), index);
auto srcElement =
builder.create<hw::ArrayGetOp>(loc, assignOp.getSrc(), index);
builder.create<sv::PAssignOp>(loc, dstElement, srcElement);
}

// Remove original assignment.
return true;
})
.Case<sv::RegOp>([&](sv::RegOp regOp) {
// Transform array reg into individual regs for each array element.
auto ty = hw::type_dyn_cast<hw::ArrayType>(regOp.getElementType());
if (!ty)
return false;

OpBuilder builder(regOp);
auto name = StringAttr::get(regOp.getContext(), "name");
SmallVector<Value> elements;
for (size_t i = 0, e = ty.getNumElements(); i < e; i++) {
auto loc = op.getLoc();
auto element = builder.create<sv::RegOp>(loc, ty.getElementType());
if (auto nameAttr = regOp->getAttrOfType<StringAttr>(name)) {
element.setNameAttr(
StringAttr::get(regOp.getContext(), nameAttr.getValue()));
}
elements.push_back(element);
}

// Fix users to refer to individual element regs.
if (!processUsers(op, regOp.getResult(), elements))
return false;

// Remove original reg.
return true;
})
.Default([&](auto op) { return false; });
}

Value HWLegalizeModulesPass::lowerLookupToCasez(Operation &op, Value input,
Value index,
mlir::Type elementType,
SmallVector<Value> caseValues) {
// Create the wire for the result of the casez in the
// hw.module.
OpBuilder builder(&op);
auto theWire = builder.create<sv::RegOp>(op.getLoc(), elementType,
builder.getStringAttr("casez_tmp"));
builder.setInsertionPoint(getOp);
builder.setInsertionPoint(&op);

auto loc = getOp.getInput().getDefiningOp()->getLoc();
// A casez is a procedural operation, so if we're in a non-procedural region
// we need to inject an always_comb block.
if (!getOp->getParentOp()->hasTrait<sv::ProceduralRegion>()) {
auto loc = input.getDefiningOp()->getLoc();
// A casez is a procedural operation, so if we're in a
// non-procedural region we need to inject an always_comb
// block.
if (!op.getParentOp()->hasTrait<sv::ProceduralRegion>()) {
auto alwaysComb = builder.create<sv::AlwaysCombOp>(loc);
builder.setInsertionPointToEnd(alwaysComb.getBodyBlock());
}

// If we are missing elements in the array (it is non-power of two), then
// add a default 'X' value.
// If we are missing elements in the array (it is non-power of
// two), then add a default 'X' value.
if (1ULL << index.getType().getIntOrFloatBitWidth() != caseValues.size()) {
caseValues.push_back(
builder.create<sv::ConstantXOp>(getOp.getLoc(), getOp.getType()));
caseValues.push_back(builder.create<sv::ConstantXOp>(
op.getLoc(), op.getResult(0).getType()));
}

APInt caseValue(index.getType().getIntOrFloatBitWidth(), 0);
Expand All @@ -107,9 +279,10 @@ Operation *HWLegalizeModulesPass::tryLoweringArrayGet(hw::ArrayGetOp getOp) {
builder.create<sv::CaseOp>(
loc, CaseStmtType::CaseZStmt, index, caseValues.size(),
[&](size_t caseIdx) -> std::unique_ptr<sv::CasePattern> {
// Use a default pattern for the last value, even if we are complete.
// This avoids tools thinking they need to insert a latch due to
// potentially incomplete case coverage.
// Use a default pattern for the last value, even if we
// are complete. This avoids tools thinking they need to
// insert a latch due to potentially incomplete case
// coverage.
bool isDefault = caseIdx == caseValues.size() - 1;
Value theValue = caseValues[caseIdx];
std::unique_ptr<sv::CasePattern> thePattern;
Expand All @@ -123,12 +296,52 @@ Operation *HWLegalizeModulesPass::tryLoweringArrayGet(hw::ArrayGetOp getOp) {
return thePattern;
});

// Ok, emit the read from the wire to get the value out.
builder.setInsertionPoint(getOp);
auto readWire = builder.create<sv::ReadInOutOp>(getOp.getLoc(), theWire);
getOp.getResult().replaceAllUsesWith(readWire);
getOp->erase();
return readWire;
return theWire;
}

bool HWLegalizeModulesPass::processUsers(Operation &op, Value value,
ArrayRef<Value> mapping) {
for (auto *user : llvm::make_early_inc_range(value.getUsers())) {
if (TypeSwitch<Operation *, bool>(user)
.Case<hw::ArrayGetOp>([&](hw::ArrayGetOp getOp) {
if (auto indexAndBitWidth =
tryExtractIndexAndBitWidth(getOp.getIndex())) {
getOp.replaceAllUsesWith(mapping[indexAndBitWidth->first]);
return true;
}

return false;
})
.Case<sv::ArrayIndexInOutOp>([&](sv::ArrayIndexInOutOp indexOp) {
if (auto indexAndBitWidth =
tryExtractIndexAndBitWidth(indexOp.getIndex())) {
indexOp.replaceAllUsesWith(mapping[indexAndBitWidth->first]);
return true;
}

return false;
})
.Default([](auto op) { return false; })) {
user->erase();
continue;
}

user->emitError("unsupported packed array expression");
signalPassFailure();
return false;
}

return true;
}

std::optional<std::pair<uint64_t, unsigned>>
HWLegalizeModulesPass::tryExtractIndexAndBitWidth(Value value) {
if (auto constantOp = dyn_cast<hw::ConstantOp>(value.getDefiningOp())) {
auto index = constantOp.getValue();
return std::make_optional(
std::make_pair(index.getZExtValue(), index.getBitWidth()));
}
return std::nullopt;
}

void HWLegalizeModulesPass::processPostOrder(Block &body) {
Expand All @@ -154,26 +367,16 @@ void HWLegalizeModulesPass::processPostOrder(Block &body) {
}

if (options.disallowPackedArrays) {
// Try idioms for lowering array_get operations.
if (auto getOp = dyn_cast<hw::ArrayGetOp>(op))
if (auto *replacement = tryLoweringArrayGet(getOp)) {
it = Block::iterator(replacement);
anythingChanged = true;
continue;
}

// If this is a dead array, then we can just delete it. This is
// probably left over from get/create lowering.
if (isa<hw::ArrayCreateOp, hw::AggregateConstantOp>(op) &&
op.use_empty()) {
// Try supported packed array op lowering.
if (tryLoweringPackedArrayOp(op)) {
it = --Block::iterator(op);
op.erase();
anythingChanged = true;
continue;
}

// Otherwise, if we aren't allowing multi-dimensional arrays, reject the
// IR as invalid.
// TODO: We should eventually implement a "lower types" like feature in
// this pass.
// Otherwise, if the IR produces a packed array and we aren't allowing
// multi-dimensional arrays, reject the IR as invalid.
for (auto value : op.getResults()) {
if (value.getType().isa<hw::ArrayType>()) {
op.emitError("unsupported packed array expression");
Expand Down
Loading

0 comments on commit bf712b0

Please sign in to comment.