Skip to content

Commit

Permalink
[FIRRTL][CAPI] Add function for getting mask type
Browse files Browse the repository at this point in the history
  • Loading branch information
SpriteOvO committed Apr 22, 2024
1 parent db973f1 commit 172b5ea
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/circt-c/Dialect/FIRRTL.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ MLIR_CAPI_EXPORTED MlirType
firrtlTypeGetClass(MlirContext ctx, MlirAttribute name, size_t numberOfElements,
const FIRRTLClassElement *elements);

MLIR_CAPI_EXPORTED MlirType firrtlTypeGetMaskType(MlirType type);

//===----------------------------------------------------------------------===//
// Attribute API.
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions lib/CAPI/Dialect/FIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ MlirType firrtlTypeGetClass(MlirContext ctx, MlirAttribute name,
return wrap(ClassType::get(unwrap(ctx), nameSymbol, classElements));
}

MlirType firrtlTypeGetMaskType(MlirType type) {
auto baseType = type_dyn_cast<FIRRTLBaseType>(unwrap(type));
assert(baseType && "unexpected type, must be base type");
return wrap(baseType.getMaskType());
}

//===----------------------------------------------------------------------===//
// Attribute API.
//===----------------------------------------------------------------------===//
Expand Down
50 changes: 50 additions & 0 deletions test/CAPI/firrtl.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <stdlib.h>
#include <string.h>

#define ARRAY_SIZE(arr) (sizeof((arr)) / sizeof((arr)[0]))

void dumpCallback(MlirStringRef message, void *userData) {
fprintf(stderr, "%.*s", (int)message.length, message.data);
}
Expand Down Expand Up @@ -184,12 +186,60 @@ void testAttrGetIntegerFromString(MlirContext ctx) {
mlirStringRefCreateFromCString("114514"), 10));
}

void testTypeGetMaskType(MlirContext ctx) {
assert(mlirTypeEqual(firrtlTypeGetMaskType(firrtlTypeGetUInt(ctx, 32)),
firrtlTypeGetUInt(ctx, 1)));
assert(mlirTypeEqual(firrtlTypeGetMaskType(firrtlTypeGetSInt(ctx, 64)),
firrtlTypeGetUInt(ctx, 1)));

FIRRTLBundleField lhsFields[] = {
{
.name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f1")),
.isFlip = false,
.type = firrtlTypeGetUInt(ctx, 32),
},
{
.name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f2")),
.isFlip = false,
.type = firrtlTypeGetSInt(ctx, 64),
},
{
.name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f3")),
.isFlip = false,
.type = firrtlTypeGetClock(ctx),
},
};
FIRRTLBundleField rhsFields[] = {
{
.name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f1")),
.isFlip = false,
.type = firrtlTypeGetUInt(ctx, 1),
},
{
.name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f2")),
.isFlip = false,
.type = firrtlTypeGetUInt(ctx, 1),
},
{
.name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("f3")),
.isFlip = false,
.type = firrtlTypeGetUInt(ctx, 1),
},
};
MlirType lhsBundle =
firrtlTypeGetBundle(ctx, ARRAY_SIZE(lhsFields), lhsFields);
MlirType rhsBundle =
firrtlTypeGetBundle(ctx, ARRAY_SIZE(rhsFields), rhsFields);
assert(mlirTypeEqual(firrtlTypeGetMaskType(lhsBundle), rhsBundle));
}

int main(void) {
MlirContext ctx = mlirContextCreate();
mlirDialectHandleLoadDialect(mlirGetDialectHandle__firrtl__(), ctx);
testExport(ctx);
testValueFoldFlow(ctx);
testImportAnnotations(ctx);
testAttrGetIntegerFromString(ctx);
testTypeGetMaskType(ctx);
return 0;
}

0 comments on commit 172b5ea

Please sign in to comment.