Skip to content

Commit

Permalink
Merge pull request #31 from huxuan0307/gh-rvv-cpu
Browse files Browse the repository at this point in the history
arch-riscv: Add vcompress.vm inst
  • Loading branch information
ksco authored May 30, 2023
2 parents 0d3b7b4 + aa32f06 commit 2909be3
Show file tree
Hide file tree
Showing 9 changed files with 588 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/arch/isa_parser/operand_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def makeRead(self):

def makeWrite(self):
return f'''
xc->setRegOperand(this, 0, &tmp_d{self.dest_reg_idx});
xc->setRegOperand(this,{self.dest_reg_idx}, &tmp_d{self.dest_reg_idx});
if (traceData) {{
traceData->setData(tmp_d{self.dest_reg_idx});
}}
Expand Down
71 changes: 71 additions & 0 deletions src/arch/riscv/insts/vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,17 @@ std::string VleMicroInst::generateDisassembly(Addr pc,
return ss.str();
}

std::string VleffMicroInst::generateDisassembly(Addr pc,
const loader::SymbolTable *symtab) const
{
std::stringstream ss;
ss << mnemonic << ' ' << registerName(destRegIdx(0)) << ", "
<< VLENB * microIdx << '(' << registerName(srcRegIdx(0)) << ')' << ", "
<< registerName(srcRegIdx(1));
if (!machInst.vm) ss << ", v0.t";
return ss.str();
}

std::string VlWholeMicroInst::generateDisassembly(Addr pc,
const loader::SymbolTable *symtab) const
{
Expand Down Expand Up @@ -387,5 +398,65 @@ VMvWholeMicroInst::generateDisassembly(Addr pc,
return ss.str();
}

VleffEndMicroInst::VleffEndMicroInst(ExtMachInst extMachInst, uint8_t _numSrcs)
: VectorMicroInst("VleffEnd", extMachInst,
VectorIntegerArithOp, 0, 0)
{
setRegIdxArrays(
reinterpret_cast<RegIdArrayPtr>(
&std::remove_pointer_t<decltype(this)>::srcRegIdxArr),
reinterpret_cast<RegIdArrayPtr>(
&std::remove_pointer_t<decltype(this)>::destRegIdxArr));
_numSrcRegs = 0;
_numDestRegs = 0;
for (uint8_t i = 0; i < _numSrcs; i++) {
setSrcRegIdx(_numSrcRegs++, vecRegClass[VecMemInternalReg0 + i]);
}
this->numSrcs = _numSrcs;
printf("VleffEndMicroInst numSrc: %hhu, numDestRegs: %hhu\n", this->numSrcs, _numDestRegs);

flags[IsNonSpeculative] = true;
flags[IsSerializeAfter] = true;
}

Fault
VleffEndMicroInst::execute(ExecContext* xc, Trace::InstRecord* traceData) const
{
printf("VleffEndMicroInst::execute begin\n");
vreg_t cnt[8];
for (uint8_t i = 0; i < this->numSrcs; i++) {
xc->getRegOperand(this, i, cnt + i);
}

printf("VleffEndMicroInst::execute getRegOperand done\n");

// [[maybe_unused]]uint64_t vl = *(uint64_t*)xc->getWritableRegOperand(this, 0);
printf("VleffEndMicroInst::execute getWritableRegOperand done\n");

uint64_t new_vl = 0;
for (uint8_t i = 0; i < this->numSrcs; i++) {
new_vl += cnt[i].as<uint64_t>()[0];
}
printf("VleffEndMicroInst::execute new_vl sum done\n");

// xc->setRegOperand(this, 0, new_vl);
xc->setMiscReg(MISCREG_VL, new_vl);

printf("VleffEndMicroInst::execute setRegOperand done\n");

if (traceData)
traceData->setData(new_vl);
printf("VleffEndMicroInst::execute end\n");
return NoFault;
}

std::string
VleffEndMicroInst::generateDisassembly(Addr pc, const loader::SymbolTable *symtab) const
{
std::stringstream ss;
ss << mnemonic << ' ' << registerName(destRegIdx(0));
return ss.str();
}

} // namespace RiscvISA
} // namespace gem5
260 changes: 260 additions & 0 deletions src/arch/riscv/insts/vector.hh
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@

#include <string>

#include "arch/riscv/faults.hh"
#include "arch/riscv/insts/static_inst.hh"
#include "arch/riscv/regs/misc.hh"
#include "arch/riscv/utility.hh"
#include "base/bitfield.hh"
#include "cpu/exec_context.hh"
#include "cpu/static_inst.hh"

Expand Down Expand Up @@ -324,6 +326,36 @@ class VseMicroInst : public VectorMicroInst
Addr pc, const loader::SymbolTable *symtab) const override;
};

class VleffMicroInst : public VectorMicroInst
{
protected:
Request::Flags memAccessFlags;

VleffMicroInst(const char *mnem, ExtMachInst _machInst, OpClass __opClass,
uint8_t _microVl, uint8_t _microIdx)
: VectorMicroInst(mnem, _machInst, __opClass, _microVl, _microIdx)
{
this->flags[IsLoad] = true;
}

std::string generateDisassembly(
Addr pc, const loader::SymbolTable *symtab) const override;
};

class VleffEndMicroInst : public VectorMicroInst
{
private:
RegId srcRegIdxArr[8]; // vle tmp target, used to keep RAW sequence
RegId destRegIdxArr[1]; // vstart
uint8_t numSrcs;
public:
VleffEndMicroInst(ExtMachInst extMachInst, uint8_t _numSrcs);

Fault execute(ExecContext* xc, Trace::InstRecord* traceData) const override;

std::string generateDisassembly(Addr pc, const loader::SymbolTable *symtab) const override;
};

class VlWholeMacroInst : public VectorMemMacroInst
{
protected:
Expand Down Expand Up @@ -621,6 +653,234 @@ class VxsatMicroInst : public VectorArithMicroInst
}
};

class VCompressPopcMicroInst : public VectorArithMicroInst
{
private:
RegId srcRegIdxArr[1]; // vm
RegId destRegIdxArr[1]; // vcnt

public:
VCompressPopcMicroInst(ExtMachInst extMachInst)
: VectorArithMicroInst("VPopCount", extMachInst,
VectorIntegerArithOp, 0, 0)
{
setRegIdxArrays(reinterpret_cast<RegIdArrayPtr>(
&std::remove_pointer_t<decltype(this)>::srcRegIdxArr),
reinterpret_cast<RegIdArrayPtr>(
&std::remove_pointer_t<decltype(this)>::destRegIdxArr));
_numSrcRegs = 0;
_numDestRegs = 0;
setDestRegIdx(_numDestRegs++, vecRegClass[VecMemInternalReg0]);
_numTypedDestRegs[VecRegClass]++;
setSrcRegIdx(_numSrcRegs++, vecRegClass[extMachInst.vs1]);
}

Fault execute(ExecContext* xc, Trace::InstRecord* traceData) const override
{

const int8_t vlmul = vtype_vlmul(vtype);
const size_t sew = vtype_SEW(vtype);
const size_t countN = VLEN / sew;
const uint32_t numVRegs = 1 << std::max<int64_t>(0, vlmul);

size_t cnt[8] = {0};
auto popcount_in_byte = [](uint8_t* addr, uint8_t msb, uint8_t lsb)
-> int {
return popCount(mask(msb, lsb) & *addr);
};

auto popcount_byte = [](uint8_t* l_addr, uint8_t* r_addr) -> int {
size_t res = 0;
while (l_addr < r_addr) {
res += popCount(*l_addr);
l_addr++;
}
return res;
};

vreg_t vs;
xc->getRegOperand(this, 0, &vs);
vreg_t vd = *(vreg_t *)xc->getWritableRegOperand(this, 0);

for (int i = 0; i < std::max<int8_t>(1, numVRegs); i++) {
uint8_t* base_addr = vs.as<uint8_t>() + i*countN/8;
if (countN < 8) {
cnt[i] = popcount_in_byte(base_addr, i * countN % 8,
(i + 1)*countN % 8);
} else {
cnt[i] = popcount_byte(base_addr, base_addr + countN / 8);
}
}

for (int i = 0; i < std::max<int8_t>(1, numVRegs); i++) {
vd.as<uint8_t>()[i] = cnt[i];
}

xc->setRegOperand(this, 0, &vd);
if (traceData)
traceData->setData(vd);
return NoFault;
}

std::string generateDisassembly(Addr pc, const loader::SymbolTable *symtab)
const override
{
std::stringstream ss;
ss << mnemonic << ' ' << registerName(destRegIdx(0)) << ", "
<< registerName(srcRegIdx(0)) << ", ";
return ss.str();
}

};

template<typename Type>
class VCompressMicroInst : public VectorArithMicroInst
{
private:
RegId srcRegIdxArr[4]; // vs, vcnt, vm, old_vd
RegId destRegIdxArr[1]; // vd
uint8_t vsIdx;
uint8_t vdIdx;
public:
VCompressMicroInst(ExtMachInst extMachInst, uint8_t microVl,
uint8_t microIdx, uint8_t _vsIdx, uint8_t _vdIdx)
: VectorArithMicroInst("Vcompress_micro", extMachInst,
VectorIntegerArithOp, microVl, microIdx)
, vsIdx(_vsIdx), vdIdx(_vdIdx)
{
setRegIdxArrays(
reinterpret_cast<RegIdArrayPtr>(
&std::remove_pointer_t<decltype(this)>::srcRegIdxArr),
reinterpret_cast<RegIdArrayPtr>(
&std::remove_pointer_t<decltype(this)>::destRegIdxArr));

_numSrcRegs = 0;
_numDestRegs = 0;
setDestRegIdx(_numDestRegs++, vecRegClass[extMachInst.vd + _vdIdx]);
_numTypedDestRegs[VecRegClass]++;
// vs
setSrcRegIdx(_numSrcRegs++, vecRegClass[extMachInst.vs2 + _vsIdx]);
// vcnt
setSrcRegIdx(_numSrcRegs++, vecRegClass[VecMemInternalReg0]);
// vm
setSrcRegIdx(_numSrcRegs++, vecRegClass[extMachInst.vs1]);
// old_vd
setSrcRegIdx(_numSrcRegs++, vecRegClass[extMachInst.vd + _vdIdx]);
}

Fault execute(ExecContext* xc, Trace::InstRecord* traceData) const override
{
const int sew = vtype_SEW(vtype);
const int uvlmax = VLEN / sew;

vreg_t vs, vcnt, vm, old_vd;
xc->getRegOperand(this, 0, &vs);
xc->getRegOperand(this, 1, &vcnt);
xc->getRegOperand(this, 2, &vm);
xc->getRegOperand(this, 3, &old_vd);

vreg_t vd = *(vreg_t *)xc->getWritableRegOperand(this, 0);
memcpy(vd.as<uint8_t>(), old_vd.as<uint8_t>(), VLENB);

vreg_t vtmp;

auto vcnt_get_elem = [&](int idx) -> size_t {
return vcnt.as<uint8_t>()[idx];
};

int num_vs_elem_moved = 0;
int num_vd_elem_moved = vdIdx * uvlmax;
for (int i = 0; i < vsIdx; i++) {
num_vs_elem_moved += vcnt_get_elem(i);
}

int vtmpIdx = 0;
for (int i = 0; i < microVl; i++) {
if (elem_mask(vm.as<uint8_t>(), i + vsIdx * uvlmax)) {
vtmp.as<Type>()[vtmpIdx++] = vs.as<Type>()[i];
}
}
int vsElemIdxBase = std::max(0, num_vd_elem_moved - num_vs_elem_moved);
int vdElemIdxBase = std::max(0, num_vs_elem_moved - num_vd_elem_moved);

for (; vsElemIdxBase < vtmpIdx && vdElemIdxBase < microVl;) {
vd.as<Type>()[vdElemIdxBase++] = vtmp.as<Type>()[vsElemIdxBase++];
}
xc->setRegOperand(this, 0, &vd);
if (traceData)
traceData->setData(vd);
return NoFault;
}

std::string generateDisassembly(Addr pc, const loader::SymbolTable *symtab)
const override
{
std::stringstream ss;
ss << mnemonic << ' ' << registerName(destRegIdx(0)) << ", "
<< registerName(srcRegIdx(0)) << ", "
<< registerName(srcRegIdx(1)) << ", "
<< registerName(srcRegIdx(2)) << ", "
<< registerName(srcRegIdx(3)) << ", ";
return ss.str();
}
};

template<typename Type>
class Vcompress_vm : public VectorArithMacroInst
{
private:
RegId srcRegIdxArr[2]; // vs, vm
RegId destRegIdxArr[1]; // vd
public:
Vcompress_vm(ExtMachInst _machInst)
: VectorArithMacroInst("vcompress_vm", _machInst, VectorIntegerArithOp)
{
setRegIdxArrays(
reinterpret_cast<RegIdArrayPtr>(
&std::remove_pointer_t<decltype(this)>::srcRegIdxArr),
reinterpret_cast<RegIdArrayPtr>(
&std::remove_pointer_t<decltype(this)>::destRegIdxArr));
_numSrcRegs = 0;
_numDestRegs = 0;
setDestRegIdx(_numDestRegs++, vecRegClass[_machInst.vd]);
_numTypedDestRegs[VecRegClass]++;
setSrcRegIdx(_numSrcRegs++, vecRegClass[_machInst.vs1]);
setSrcRegIdx(_numSrcRegs++, vecRegClass[_machInst.vs2]);

const uint32_t num_microops = vtype_regs_per_group(vtype);
int32_t tmp_vl = this->vl;
const int32_t micro_vlmax = vtype_VLMAX(_machInst.vtype8, true);
int32_t micro_vl = std::min(tmp_vl, micro_vlmax);

StaticInstPtr microop;
microop = new VCompressPopcMicroInst(_machInst);
this->microops.push_back(microop);

int8_t microIdx = 0;
for (int i = 0; i < num_microops && micro_vl > 0; ++i) {
for (int j = 0; j <= i; ++j) {
microop = new VCompressMicroInst<Type>(
_machInst, micro_vl, microIdx++, i, j);
microop->setDelayedCommit();
this->microops.push_back(microop);
}
micro_vl = std::min(tmp_vl -= micro_vlmax, micro_vlmax);
}
this->microops.front()->setFirstMicroop();
this->microops.back()->setLastMicroop();
}

std::string generateDisassembly(Addr pc, const loader::SymbolTable *symtab)
const override
{
std::stringstream ss;
ss << mnemonic << ' ' << registerName(destRegIdx(0)) << ", "
<< registerName(srcRegIdx(1)) << ", "
<< registerName(srcRegIdx(0));
return ss.str();
}
};

} // namespace RiscvISA
} // namespace gem5

Expand Down
Loading

0 comments on commit 2909be3

Please sign in to comment.