Skip to content

Commit

Permalink
[RISC-V] Fix gc-related bugs in risc-v emitter (#98226)
Browse files Browse the repository at this point in the history
* [RISC-V] Fix mistakes in emitter

* Revert "[RISC-V] Added designated output instruction emitters (#96741)"

This reverts commit 77fd98c.

* Revert "Revert "[RISC-V] Added designated output instruction emitters (#96741)""

This reverts commit ecc044d.

* [RISC-V] Sync emitOutputIns with the latest ref branch

* [RISC-V] Formatted code

* [RISC-V] Fixes

* [RISC-V] Fixed sign cast in assert code len

* [RISC-V] Re-added assert

* [RISC-V] Fixed fence sanity check and removed fence_i

---------

Co-authored-by: Dong-Heon Jung <[email protected]>
  • Loading branch information
Bajtazar and clamp03 authored Feb 13, 2024
1 parent 4d69ab7 commit e3af00b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 67 deletions.
150 changes: 86 additions & 64 deletions src/coreclr/jit/emitriscv64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2123,7 +2123,7 @@ unsigned emitter::emitOutput_Instr(BYTE* dst, code_t code) const
return sizeof(code_t);
}

static inline void assertCodeLength(unsigned code, uint8_t size)
static inline void assertCodeLength(size_t code, uint8_t size)
{
assert((code >> size) == 0);
}
Expand Down Expand Up @@ -2298,7 +2298,9 @@ static inline void assertCodeLength(unsigned code, uint8_t size)

static constexpr unsigned kInstructionOpcodeMask = 0x7f;
static constexpr unsigned kInstructionFunct3Mask = 0x7000;
static constexpr unsigned kInstructionFunct5Mask = 0xf8000000;
static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
static constexpr unsigned kInstructionFunct2Mask = 0x06000000;

#ifdef DEBUG

Expand Down Expand Up @@ -2338,34 +2340,44 @@ static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
assert(isGeneralRegisterOrR0(rs1));
assert(isGeneralRegisterOrR0(rs2));
break;
case INS_fadd_s:
case INS_fsub_s:
case INS_fmul_s:
case INS_fdiv_s:
case INS_fsgnj_s:
case INS_fsgnjn_s:
case INS_fsgnjx_s:
case INS_fmin_s:
case INS_fmax_s:
case INS_feq_s:
case INS_flt_s:
case INS_fle_s:
case INS_fadd_d:
case INS_fsub_d:
case INS_fmul_d:
case INS_fdiv_d:
case INS_fsgnj_d:
case INS_fsgnjn_d:
case INS_fsgnjx_d:
case INS_fmin_d:
case INS_fmax_d:
assert(isFloatReg(rd));
assert(isFloatReg(rs1));
assert(isFloatReg(rs2));
break;
case INS_feq_s:
case INS_feq_d:
case INS_flt_d:
case INS_flt_s:
case INS_fle_s:
case INS_fle_d:
assert(isFloatReg(rd));
assert(isGeneralRegisterOrR0(rd));
assert(isFloatReg(rs1));
assert(isFloatReg(rs2));
break;
case INS_fmv_w_x:
case INS_fmv_d_x:
assert(isFloatReg(rd));
assert(isGeneralRegisterOrR0(rs1));
assert(rs2 == 0);
break;
case INS_fmv_x_d:
case INS_fmv_x_w:
case INS_fclass_s:
case INS_fclass_d:
assert(isGeneralRegisterOrR0(rd));
assert(isFloatReg(rs1));
assert(rs2 == 0);
break;
default:
NO_WAY("Illegal ins within emitOutput_RTypeInstr!");
break;
Expand All @@ -2377,6 +2389,7 @@ static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
{
switch (ins)
{
case INS_mov:
case INS_jalr:
case INS_lb:
case INS_lh:
Expand All @@ -2392,7 +2405,6 @@ static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
case INS_lwu:
case INS_ld:
case INS_addiw:
case INS_fence_i:
case INS_csrrw:
case INS_csrrs:
case INS_csrrc:
Expand Down Expand Up @@ -2427,6 +2439,15 @@ static constexpr unsigned kInstructionFunct7Mask = 0xfe000000;
assert(rs1 < 32);
assert((opcode & kInstructionFunct7Mask) == 0);
break;
case INS_fence:
{
assert(rd == REG_ZERO);
assert(rs1 == REG_ZERO);
ssize_t format = immediate >> 8;
assert((format == 0) || (format == 0x8));
assert((opcode & kInstructionFunct7Mask) == 0);
}
break;
default:
NO_WAY("Illegal ins within emitOutput_ITypeInstr!");
break;
Expand Down Expand Up @@ -2867,7 +2888,7 @@ BYTE* emitter::emitOutputInstr_OptsI8(BYTE* dst, const instrDesc* id, ssize_t im
if (id->idReg2())
{
// special for INT64_MAX or UINT32_MAX
dst += emitOutput_ITypeInstr(dst, INS_addi, reg1, REG_R0, 0xfff);
dst += emitOutput_ITypeInstr(dst, INS_addi, reg1, REG_R0, NBitMask(12));
const ssize_t shiftValue = (immediate == INT64_MAX) ? 1 : 32;
dst += emitOutput_ITypeInstr(dst, INS_srli, reg1, reg1, shiftValue);
}
Expand All @@ -2881,10 +2902,10 @@ BYTE* emitter::emitOutputInstr_OptsI8(BYTE* dst, const instrDesc* id, ssize_t im

BYTE* emitter::emitOutputInstr_OptsI32(BYTE* dst, ssize_t immediate, regNumber reg1)
{
ssize_t upperWord = UpperWordOfDoubleWord(immediate);
const ssize_t upperWord = UpperWordOfDoubleWord(immediate);
dst += emitOutput_UTypeInstr(dst, INS_lui, reg1, UpperNBitsOfWordSignExtend<20>(upperWord));
dst += emitOutput_ITypeInstr(dst, INS_addi, reg1, reg1, LowerNBitsOfWord<12>(upperWord));
ssize_t lowerWord = LowerWordOfDoubleWord(immediate);
const ssize_t lowerWord = LowerWordOfDoubleWord(immediate);
dst += emitOutput_ITypeInstr(dst, INS_slli, reg1, reg1, 11);
dst += emitOutput_ITypeInstr(dst, INS_addi, reg1, reg1, LowerNBitsOfWord<11>(lowerWord >> 21));
dst += emitOutput_ITypeInstr(dst, INS_slli, reg1, reg1, 11);
Expand All @@ -2899,39 +2920,37 @@ BYTE* emitter::emitOutputInstr_OptsRc(BYTE* dst, const instrDesc* id, instructio
assert(id->idAddr()->iiaIsJitDataOffset());
assert(id->idGCref() == GCT_NONE);

int dataOffs = id->idAddr()->iiaGetJitDataOffset();
const int dataOffs = id->idAddr()->iiaGetJitDataOffset();
assert(dataOffs >= 0);

ssize_t immediate = emitGetInsSC(id);
const ssize_t immediate = emitGetInsSC(id);
assert((immediate >= 0) && (immediate < 0x4000)); // 0x4000 is arbitrary, currently 'imm' is always 0.

unsigned offset = static_cast<unsigned>(dataOffs + immediate);
const unsigned offset = static_cast<unsigned>(dataOffs + immediate);
assert(offset < emitDataSize());

*ins = id->idIns();
regNumber reg1 = id->idReg1();
*ins = id->idIns();
const regNumber reg1 = id->idReg1();

if (id->idIsReloc())
{
return emitOutputInstr_OptsRcReloc(dst, ins, reg1);
return emitOutputInstr_OptsRcReloc(dst, ins, offset, reg1);
}
return emitOutputInstr_OptsRcNoReloc(dst, ins, offset, reg1);
}

BYTE* emitter::emitOutputInstr_OptsRcReloc(BYTE* dst, instruction* ins, regNumber reg1)
BYTE* emitter::emitOutputInstr_OptsRcReloc(BYTE* dst, instruction* ins, unsigned offset, regNumber reg1)
{
ssize_t immediate = emitConsBlock - dst;
assert(immediate > 0);
assert((immediate & 0x03) == 0);
const ssize_t immediate = (emitConsBlock - dst) + offset;
assert((immediate > 0) && ((immediate & 0x03) == 0));

regNumber rsvdReg = codeGen->rsGetRsvdReg();
const regNumber rsvdReg = codeGen->rsGetRsvdReg();
dst += emitOutput_UTypeInstr(dst, INS_auipc, rsvdReg, UpperNBitsOfWordSignExtend<20>(immediate));

instruction lastIns = *ins;

if (*ins == INS_jal)
{
assert(isGeneralRegister(reg1));
*ins = lastIns = INS_addi;
}
dst += emitOutput_ITypeInstr(dst, lastIns, reg1, rsvdReg, LowerNBitsOfWord<12>(immediate));
Expand All @@ -2940,12 +2959,12 @@ BYTE* emitter::emitOutputInstr_OptsRcReloc(BYTE* dst, instruction* ins, regNumbe

BYTE* emitter::emitOutputInstr_OptsRcNoReloc(BYTE* dst, instruction* ins, unsigned offset, regNumber reg1)
{
ssize_t immediate = reinterpret_cast<ssize_t>(emitConsBlock) + offset;
assert((immediate >> 40) == 0);
regNumber rsvdReg = codeGen->rsGetRsvdReg();
const ssize_t immediate = reinterpret_cast<ssize_t>(emitConsBlock) + offset;
assertCodeLength(static_cast<size_t>(immediate), 40);
const regNumber rsvdReg = codeGen->rsGetRsvdReg();

instruction lastIns = (*ins == INS_jal) ? (*ins = INS_addi) : *ins;
UINT32 high = immediate >> 11;
const instruction lastIns = (*ins == INS_jal) ? (*ins = INS_addi) : *ins;
const UINT32 high = immediate >> 11;

dst += emitOutput_UTypeInstr(dst, INS_lui, rsvdReg, UpperNBitsOfWordSignExtend<20>(high));
dst += emitOutput_ITypeInstr(dst, INS_addi, rsvdReg, rsvdReg, LowerNBitsOfWord<12>(high));
Expand All @@ -2959,9 +2978,8 @@ BYTE* emitter::emitOutputInstr_OptsRl(BYTE* dst, instrDesc* id, instruction* ins
insGroup* targetInsGroup = static_cast<insGroup*>(emitCodeGetCookie(id->idAddr()->iiaBBlabel));
id->idAddr()->iiaIGlabel = targetInsGroup;

regNumber reg1 = id->idReg1();
assert(isGeneralRegister(reg1));
ssize_t igOffs = targetInsGroup->igOffs;
const regNumber reg1 = id->idReg1();
const ssize_t igOffs = targetInsGroup->igOffs;

if (id->idIsReloc())
{
Expand All @@ -2974,7 +2992,7 @@ BYTE* emitter::emitOutputInstr_OptsRl(BYTE* dst, instrDesc* id, instruction* ins

BYTE* emitter::emitOutputInstr_OptsRlReloc(BYTE* dst, ssize_t igOffs, regNumber reg1)
{
ssize_t immediate = (emitCodeBlock - dst) + igOffs;
const ssize_t immediate = (emitCodeBlock - dst) + igOffs;
assert((immediate & 0x03) == 0);

dst += emitOutput_UTypeInstr(dst, INS_auipc, reg1, UpperNBitsOfWordSignExtend<20>(immediate));
Expand All @@ -2984,11 +3002,11 @@ BYTE* emitter::emitOutputInstr_OptsRlReloc(BYTE* dst, ssize_t igOffs, regNumber

BYTE* emitter::emitOutputInstr_OptsRlNoReloc(BYTE* dst, ssize_t igOffs, regNumber reg1)
{
ssize_t immediate = reinterpret_cast<ssize_t>(emitCodeBlock) + igOffs;
assert((immediate >> (32 + 20)) == 0);
const ssize_t immediate = reinterpret_cast<ssize_t>(emitCodeBlock) + igOffs;
assertCodeLength(static_cast<size_t>(immediate), 32 + 20);

regNumber rsvdReg = codeGen->rsGetRsvdReg();
ssize_t upperSignExt = UpperWordOfDoubleWordDoubleSignExtend<32, 52>(immediate);
const regNumber rsvdReg = codeGen->rsGetRsvdReg();
const ssize_t upperSignExt = UpperWordOfDoubleWordDoubleSignExtend<32, 52>(immediate);

dst += emitOutput_UTypeInstr(dst, INS_lui, rsvdReg, UpperNBitsOfWordSignExtend<20>(immediate));
dst += emitOutput_ITypeInstr(dst, INS_addi, rsvdReg, rsvdReg, LowerNBitsOfWord<12>(immediate));
Expand All @@ -3000,32 +3018,32 @@ BYTE* emitter::emitOutputInstr_OptsRlNoReloc(BYTE* dst, ssize_t igOffs, regNumbe

BYTE* emitter::emitOutputInstr_OptsJalr(BYTE* dst, instrDescJmp* jmp, const insGroup* ig, instruction* ins)
{
ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, jmp) - 4;
const ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, jmp) - 4;
assert((immediate & 0x03) == 0);

*ins = jmp->idIns();
assert(jmp->idCodeSize() > 4); // The original INS_OPTS_JALR: not used by now!!!
switch (jmp->idCodeSize())
{
case 8:
return emitOutputInstr_OptsJalr8(dst, jmp, *ins, immediate);
return emitOutputInstr_OptsJalr8(dst, jmp, immediate);
case 24:
assert((*ins == INS_jal) || (*ins == INS_j));
assert(jmp->idInsIs(INS_jal, INS_j));
return emitOutputInstr_OptsJalr24(dst, immediate);
case 28:
return emitOutputInstr_OptsJalr28(dst, jmp, *ins, immediate);
return emitOutputInstr_OptsJalr28(dst, jmp, immediate);
default:
// case 0 - 4: The original INS_OPTS_JALR: not used by now!!!
break;
}
unreached();
return nullptr;
}

BYTE* emitter::emitOutputInstr_OptsJalr8(BYTE* dst, const instrDescJmp* jmp, instruction ins, ssize_t immediate)
BYTE* emitter::emitOutputInstr_OptsJalr8(BYTE* dst, const instrDescJmp* jmp, ssize_t immediate)
{
regNumber reg2 = ((ins != INS_beqz) && (ins != INS_bnez)) ? jmp->idReg2() : REG_R0;
const regNumber reg2 = jmp->idInsIs(INS_beqz, INS_bnez) ? REG_R0 : jmp->idReg2();

dst += emitOutput_BTypeInstr_InvertComparation(dst, ins, jmp->idReg1(), reg2, 0x8);
dst += emitOutput_BTypeInstr_InvertComparation(dst, jmp->idIns(), jmp->idReg1(), reg2, 0x8);
dst += emitOutput_JTypeInstr(dst, INS_jal, REG_ZERO, TrimSignedToImm21(immediate));
return dst;
}
Expand All @@ -3034,14 +3052,14 @@ BYTE* emitter::emitOutputInstr_OptsJalr24(BYTE* dst, ssize_t immediate)
{
// Make target address with offset, then jump (JALR) with the target address
immediate -= 2 * 4;
ssize_t high = UpperWordOfDoubleWordSingleSignExtend<0>(immediate);
const ssize_t high = UpperWordOfDoubleWordSingleSignExtend<0>(immediate);

dst += emitOutput_UTypeInstr(dst, INS_lui, REG_RA, UpperNBitsOfWordSignExtend<20>(high));
dst += emitOutput_ITypeInstr(dst, INS_addi, REG_RA, REG_RA, LowerNBitsOfWord<12>(high));
dst += emitOutput_ITypeInstr(dst, INS_slli, REG_RA, REG_RA, 32);

regNumber rsvdReg = codeGen->rsGetRsvdReg();
ssize_t low = LowerWordOfDoubleWord(immediate);
const regNumber rsvdReg = codeGen->rsGetRsvdReg();
const ssize_t low = LowerWordOfDoubleWord(immediate);

dst += emitOutput_UTypeInstr(dst, INS_auipc, rsvdReg, UpperNBitsOfWordSignExtend<20>(low));
dst += emitOutput_RTypeInstr(dst, INS_add, rsvdReg, REG_RA, rsvdReg);
Expand All @@ -3050,17 +3068,18 @@ BYTE* emitter::emitOutputInstr_OptsJalr24(BYTE* dst, ssize_t immediate)
return dst;
}

BYTE* emitter::emitOutputInstr_OptsJalr28(BYTE* dst, const instrDescJmp* jmp, instruction ins, ssize_t immediate)
BYTE* emitter::emitOutputInstr_OptsJalr28(BYTE* dst, const instrDescJmp* jmp, ssize_t immediate)
{
regNumber reg2 = ((ins != INS_beqz) && (ins != INS_bnez)) ? jmp->idReg2() : REG_R0;
dst += emitOutput_BTypeInstr_InvertComparation(dst, ins, jmp->idReg1(), reg2, 0x1c);
regNumber reg2 = jmp->idInsIs(INS_beqz, INS_bnez) ? REG_R0 : jmp->idReg2();

dst += emitOutput_BTypeInstr_InvertComparation(dst, jmp->idIns(), jmp->idReg1(), reg2, 0x1c);

return emitOutputInstr_OptsJalr24(dst, immediate);
}

BYTE* emitter::emitOutputInstr_OptsJCond(BYTE* dst, instrDesc* id, const insGroup* ig, instruction* ins)
{
ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, static_cast<instrDescJmp*>(id));
const ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, static_cast<instrDescJmp*>(id));

*ins = id->idIns();

Expand All @@ -3070,7 +3089,7 @@ BYTE* emitter::emitOutputInstr_OptsJCond(BYTE* dst, instrDesc* id, const insGrou

BYTE* emitter::emitOutputInstr_OptsJ(BYTE* dst, instrDesc* id, const insGroup* ig, instruction* ins)
{
ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, static_cast<instrDescJmp*>(id));
const ssize_t immediate = emitOutputInstrJumpDistance(dst, ig, static_cast<instrDescJmp*>(id));
assert((immediate & 0x03) == 0);

*ins = id->idIns();
Expand Down Expand Up @@ -3133,11 +3152,13 @@ BYTE* emitter::emitOutputInstr_OptsC(BYTE* dst, instrDesc* id, const insGroup* i
size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
{
BYTE* dst = *dp;
BYTE* dst2 = dst + 4;
const BYTE* const odst = *dp;
instruction ins;
size_t sz = 0;

assert(REG_NA == static_cast<int>(REG_NA));
assert(writeableOffset == 0);

insOpts insOp = id->idInsOpt();

Expand Down Expand Up @@ -3174,8 +3195,9 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
sz = sizeof(instrDescJmp);
break;
case INS_OPTS_C:
dst = emitOutputInstr_OptsC(dst, id, ig, &sz);
ins = INS_nop;
dst = emitOutputInstr_OptsC(dst, id, ig, &sz);
dst2 = dst;
ins = INS_nop;
break;
default: // case INS_OPTS_NONE:
dst += emitOutput_Instr(dst, id->idAddr()->iiaGetInstrEncode());
Expand All @@ -3193,11 +3215,11 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
// We assume that "idReg1" is the primary destination register for all instructions
if (id->idGCref() != GCT_NONE)
{
emitGCregLiveUpd(id->idGCref(), id->idReg1(), dst);
emitGCregLiveUpd(id->idGCref(), id->idReg1(), dst2);
}
else
{
emitGCregDeadUpd(id->idReg1(), dst);
emitGCregDeadUpd(id->idReg1(), dst2);
}
}

Expand All @@ -3211,7 +3233,7 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
int adr = emitComp->lvaFrameAddress(varNum, &FPbased);
if (id->idGCref() != GCT_NONE)
{
emitGCvarLiveUpd(adr + ofs, varNum, id->idGCref(), dst DEBUG_ARG(varNum));
emitGCvarLiveUpd(adr + ofs, varNum, id->idGCref(), dst2 DEBUG_ARG(varNum));
}
else
{
Expand All @@ -3228,7 +3250,7 @@ size_t emitter::emitOutputInstr(insGroup* ig, instrDesc* id, BYTE** dp)
vt = tmpDsc->tdTempType();
}
if (vt == TYP_REF || vt == TYP_BYREF)
emitGCvarDeadUpd(adr + ofs, dst DEBUG_ARG(varNum));
emitGCvarDeadUpd(adr + ofs, dst2 DEBUG_ARG(varNum));
}
// if (emitInsWritesToLclVarStackLocPair(id))
//{
Expand Down
Loading

0 comments on commit e3af00b

Please sign in to comment.