Skip to content

Commit

Permalink
fix: remove incorrect nop transpilation for intrinsics (#1306)
Browse files Browse the repository at this point in the history
* fix: remove incorrect nop transpilation for intrinsics

* feat: document that iseqmod is a no-op if rd = x0

* feat: specify validity for setup_iseq
  • Loading branch information
yi-sun authored Jan 27, 2025
1 parent a7fc816 commit dc9b8ae
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 36 deletions.
6 changes: 5 additions & 1 deletion crates/toolchain/transpiler/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ pub fn from_r_type<F: PrimeField32>(
opcode: usize,
e_as: usize,
dec_insn: &RType,
allow_rd_zero: bool,
) -> Instruction<F> {
if dec_insn.rd == 0 {
// If `rd` is not allowed to be zero, we transpile to `NOP` to prevent a write
// to `x0`. In the cases where `allow_rd_zero` is true, it is the responsibility of
// the caller to guarantee that the resulting instruction does not write to `rd`.
if !allow_rd_zero && dec_insn.rd == 0 {
return nop();
}
Instruction::new(
Expand Down
8 changes: 4 additions & 4 deletions docs/specs/RISCV.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ We use `config.mod_idx(N)` to denote the index of `N` in this list. In the list
| submod\<N\> | R | 0101011 | 000 | `idx*8+1` | `[rd: N::NUM_LIMBS]_2 = [rs1: N::NUM_LIMBS]_2 - [rs2: N::NUM_LIMBS]_2 (mod N)` |
| mulmod\<N\> | R | 0101011 | 000 | `idx*8+2` | `[rd: N::NUM_LIMBS]_2 = [rs1: N::NUM_LIMBS]_2 * [rs2: N::NUM_LIMBS]_2 (mod N)` |
| divmod\<N\> | R | 0101011 | 000 | `idx*8+3` | `[rd: N::NUM_LIMBS]_2 = [rs1: N::NUM_LIMBS]_2 / [rs2: N::NUM_LIMBS]_2 (mod N)` (undefined when `gcd([rs2: N::NUM_LIMBS]_2, N) != 1`) |
| iseqmod\<N\> | R | 0101011 | 000 | `idx*8+4` | `rd = [rs1: N::NUM_LIMBS]_2 == [rs2: N::NUM_LIMBS]_2 (mod N) ? 1 : 0`. Enforces that `[rs1: N::NUM_LIMBS]_2` and `[rs2: N::NUM_LIMBS]_2` are both less than `N` and then sets `rd` equal to boolean comparison value. |
| setup\<N\> | R | 0101011 | 000 | `idx*8+5` | `assert([rs1: N::NUM_LIMBS]_2 == N)` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes something (can be anything) into `[rd: N::NUM_LIMBS]_2` if `ind(rs2) = 0,1` (for add_sub, mul_div) or it overwrites the register value of `rd` if `ind(rs2) = 2` (for iseq). |
| iseqmod\<N\> | R | 0101011 | 000 | `idx*8+4` | `rd = [rs1: N::NUM_LIMBS]_2 == [rs2: N::NUM_LIMBS]_2 (mod N) ? 1 : 0`. If `rd != x0`, enforces that `[rs1: N::NUM_LIMBS]_2` and `[rs2: N::NUM_LIMBS]_2` are both less than `N` and then sets `rd` equal to boolean comparison value. If `rd = x0`, this is a no-op. |
| setup\<N\> | R | 0101011 | 000 | `idx*8+5` | `assert([rs1: N::NUM_LIMBS]_2 == N)` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: N::NUM_LIMBS]_2` if `ind(rs2) = 0,1` (for add_sub, mul_div) or it overwrites the register value of `rd` with an unconstrained value if `ind(rs2) = 2` (for iseq). If `ind(rs2) = 2`, then the instruction is **invalid** if `rd = x0`. |

Since `funct7` is 7-bits, up to 16 moduli can be supported simultaneously. We use `idx*8` to leave some room for future expansion.

Expand All @@ -104,7 +104,7 @@ Complex extension field arithmetic over `Fp2` depends on `Fp` where `-1` is not
| subcomplex | R | 0101011 | 010 | `idx*8+1` | Read `x: Fp2` from `[rs1..]_2` and `y: Fp2` from `[rs2..]_2`. Write `x - y` to `[rd..]_2` |
| mulcomplex | R | 0101011 | 010 | `idx*8+2` | Read `x: Fp2` from `[rs1..]_2` and `y: Fp2` from `[rs2..]_2`. Write `x * y` to `[rd..]_2` |
| divcomplex | R | 0101011 | 010 | `idx*8+3` | Read `x: Fp2` from `[rs1..]_2` and `y: Fp2` from `[rs2..]_2`. Write `x / y` to `[rd..]_2` |
| setupcomplex| R | 0101011 | 010 | `idx*8+4` | `assert([rs1: Fp::NUM_LIMBS]_2 == Fp::MODULUS)` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes something (can be anything) into `[rd: Fp::NUM_LIMBS]_2`. |
| setupcomplex| R | 0101011 | 010 | `idx*8+4` | `assert([rs1: Fp::NUM_LIMBS]_2 == Fp::MODULUS)` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: Fp::NUM_LIMBS]_2`. |

## Elliptic Curve Extension

Expand All @@ -114,7 +114,7 @@ The elliptic curve extension supports arithmetic over short Weierstrass curves,
| --------------- | --- | ----------- | ------ | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| sw_add_ne\<C\> | R | 0101011 | 001 | `idx*8` | `EcPoint([rd:2*C::COORD_SIZE]_2) = EcPoint([rs1:2*C::COORD_SIZE]_2) + EcPoint([rs2:2*C::COORD_SIZE]_2)`. Assumes that input affine points are not identity and do not have same x-coordinate. |
| sw_double\<C\> | R | 0101011 | 001 | `idx*8+1` | `EcPoint([rd:2*C::COORD_SIZE]_2) = 2 * EcPoint([rs1:2*C::COORD_SIZE]_2)`. Assumes that input affine point is not identity. `rs2` is unused and must be set to `x0`. |
| setup\<C\> | R | 0101011 | 001 | `idx*8+2` | `assert([rs1: C::COORD_SIZE]_2 == C::MODULUS)` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes something (can be anything) into `[rd: 2*C::COORD_SIZE]_2`. If `ind(rs2) != 0`, then this instruction is setup for `sw_add_ne`. Otherwise it is setup for `sw_double`. When `ind(rs2) != 0` (add_ne), it is required for proper functionality that `[rs2: C::COORD_SIZE]_2 != [rs1: C::COORD_SIZE]_2`; otherwise (double), it is required that `[rs1 + C::COORD_SIZE: C::COORD_SIZE]_2 != C::Fp::ZERO` |
| setup\<C\> | R | 0101011 | 001 | `idx*8+2` | `assert([rs1: C::COORD_SIZE]_2 == C::MODULUS)` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: 2*C::COORD_SIZE]_2`. If `ind(rs2) != 0`, then this instruction is setup for `sw_add_ne`. Otherwise it is setup for `sw_double`. When `ind(rs2) != 0` (add_ne), it is required for proper functionality that `[rs2: C::COORD_SIZE]_2 != [rs1: C::COORD_SIZE]_2`; otherwise (double), it is required that `[rs1 + C::COORD_SIZE: C::COORD_SIZE]_2 != C::Fp::ZERO` |
| hint_decompress | R | 0101011 | 001 | `idx*8+3` | Read `x: C::Fp` from `[rs1: C::COORD_SIZE]_2` and `rec_id: u8` from `[rs2]_2`. Reset the hint stream to equal the unique `y: C::Fp` such that `(x, y)` is a point on `C` and `y` has the same parity as `rec_id`, if it exists. Otherwise reset hint stream to arbitrary `C::Fp`. `rd` should be `x0`. |

Since `funct7` is 7-bits, up to 16 curves can be supported simultaneously. We use `idx*8` to leave some room for future expansion.
Expand Down
34 changes: 22 additions & 12 deletions extensions/algebra/transpiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,22 @@ impl<F: PrimeField32> TranspilerExtension<F> for ModularTranspilerExtension {
2 => Rv32ModularArithmeticOpcode::SETUP_ISEQ,
_ => panic!("invalid opcode"),
};
Some(Instruction::new(
VmOpcode::from_usize(local_opcode.global_opcode().as_usize() + mod_idx_shift),
F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
F::ZERO, // rs2 = 0
F::ONE, // d_as = 1
F::TWO, // e_as = 2
F::ZERO,
F::ZERO,
))
if local_opcode == Rv32ModularArithmeticOpcode::SETUP_ISEQ && dec_insn.rd == 0 {
panic!("SETUP_ISEQ is not valid for rd = x0");
} else {
Some(Instruction::new(
VmOpcode::from_usize(
local_opcode.global_opcode().as_usize() + mod_idx_shift,
),
F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
F::ZERO, // rs2 = 0
F::ONE, // d_as = 1
F::TWO, // e_as = 2
F::ZERO,
F::ZERO,
))
}
} else {
let global_opcode = match ModArithBaseFunct7::from_repr(base_funct7) {
Some(ModArithBaseFunct7::AddMod) => {
Expand All @@ -119,7 +125,11 @@ impl<F: PrimeField32> TranspilerExtension<F> for ModularTranspilerExtension {
_ => unimplemented!(),
};
let global_opcode = global_opcode + mod_idx_shift;
Some(from_r_type(global_opcode, 2, &dec_insn))
// The only opcode in this extension which can write to rd is `IsEqMod`
// so we cannot allow rd to be zero in this case.
let allow_rd_zero =
ModArithBaseFunct7::from_repr(base_funct7) != Some(ModArithBaseFunct7::IsEqMod);
Some(from_r_type(global_opcode, 2, &dec_insn, allow_rd_zero))
}
};
instruction.map(TranspilerOutput::one_to_one)
Expand Down Expand Up @@ -189,7 +199,7 @@ impl<F: PrimeField32> TranspilerExtension<F> for Fp2TranspilerExtension {
_ => unimplemented!(),
};
let global_opcode = global_opcode + complex_idx_shift;
Some(from_r_type(global_opcode, 2, &dec_insn))
Some(from_r_type(global_opcode, 2, &dec_insn, true))
}
};
instruction.map(TranspilerOutput::one_to_one)
Expand Down
2 changes: 1 addition & 1 deletion extensions/bigint/transpiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl<F: PrimeField32> TranspilerExtension<F> for Int256TranspilerExtension {
}
_ => unimplemented!(),
};
Some(from_r_type(global_opcode, 2, &dec_insn))
Some(from_r_type(global_opcode, 2, &dec_insn, true))
}
BEQ256_FUNCT3 => {
let dec_insn = BType::new(instruction_u32);
Expand Down
2 changes: 1 addition & 1 deletion extensions/ecc/transpiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl<F: PrimeField32> TranspilerExtension<F> for EccTranspilerExtension {
_ => unimplemented!(),
};
let global_opcode = global_opcode + curve_idx_shift;
Some(from_r_type(global_opcode, 2, &dec_insn))
Some(from_r_type(global_opcode, 2, &dec_insn, true))
}
};
instruction.map(TranspilerOutput::one_to_one)
Expand Down
1 change: 1 addition & 0 deletions extensions/keccak256/transpiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ impl<F: PrimeField32> TranspilerExtension<F> for Keccak256TranspilerExtension {
Rv32KeccakOpcode::KECCAK256.global_opcode().as_usize(),
2,
&dec_insn,
true,
);
Some(TranspilerOutput::one_to_one(instruction))
}
Expand Down
1 change: 1 addition & 0 deletions extensions/pairing/transpiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ impl<F: PrimeField32> TranspilerExtension<F> for PairingTranspilerExtension {
global_opcode,
2,
&dec_insn,
true,
)))
}
}
120 changes: 103 additions & 17 deletions extensions/rv32im/transpiler/src/rrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,67 +24,112 @@ impl<F: PrimeField32> InstructionProcessor for InstructionTranspiler<F> {
type InstructionResult = Instruction<F>;

fn process_add(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(BaseAluOpcode::ADD.global_opcode().as_usize(), 1, &dec_insn)
from_r_type(
BaseAluOpcode::ADD.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

fn process_addi(&mut self, dec_insn: IType) -> Self::InstructionResult {
from_i_type(BaseAluOpcode::ADD.global_opcode().as_usize(), &dec_insn)
}

fn process_sub(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(BaseAluOpcode::SUB.global_opcode().as_usize(), 1, &dec_insn)
from_r_type(
BaseAluOpcode::SUB.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

fn process_xor(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(BaseAluOpcode::XOR.global_opcode().as_usize(), 1, &dec_insn)
from_r_type(
BaseAluOpcode::XOR.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

fn process_xori(&mut self, dec_insn: IType) -> Self::InstructionResult {
from_i_type(BaseAluOpcode::XOR.global_opcode().as_usize(), &dec_insn)
}

fn process_or(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(BaseAluOpcode::OR.global_opcode().as_usize(), 1, &dec_insn)
from_r_type(
BaseAluOpcode::OR.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

fn process_ori(&mut self, dec_insn: IType) -> Self::InstructionResult {
from_i_type(BaseAluOpcode::OR.global_opcode().as_usize(), &dec_insn)
}

fn process_and(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(BaseAluOpcode::AND.global_opcode().as_usize(), 1, &dec_insn)
from_r_type(
BaseAluOpcode::AND.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

fn process_andi(&mut self, dec_insn: IType) -> Self::InstructionResult {
from_i_type(BaseAluOpcode::AND.global_opcode().as_usize(), &dec_insn)
}

fn process_sll(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(ShiftOpcode::SLL.global_opcode().as_usize(), 1, &dec_insn)
from_r_type(
ShiftOpcode::SLL.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

fn process_slli(&mut self, dec_insn: ITypeShamt) -> Self::InstructionResult {
from_i_type_shamt(ShiftOpcode::SLL.global_opcode().as_usize(), &dec_insn)
}

fn process_srl(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(ShiftOpcode::SRL.global_opcode().as_usize(), 1, &dec_insn)
from_r_type(
ShiftOpcode::SRL.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

fn process_srli(&mut self, dec_insn: ITypeShamt) -> Self::InstructionResult {
from_i_type_shamt(ShiftOpcode::SRL.global_opcode().as_usize(), &dec_insn)
}

fn process_sra(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(ShiftOpcode::SRA.global_opcode().as_usize(), 1, &dec_insn)
from_r_type(
ShiftOpcode::SRA.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

fn process_srai(&mut self, dec_insn: ITypeShamt) -> Self::InstructionResult {
from_i_type_shamt(ShiftOpcode::SRA.global_opcode().as_usize(), &dec_insn)
}

fn process_slt(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(LessThanOpcode::SLT.global_opcode().as_usize(), 1, &dec_insn)
from_r_type(
LessThanOpcode::SLT.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

fn process_slti(&mut self, dec_insn: IType) -> Self::InstructionResult {
Expand All @@ -96,6 +141,7 @@ impl<F: PrimeField32> InstructionProcessor for InstructionTranspiler<F> {
LessThanOpcode::SLTU.global_opcode().as_usize(),
1,
&dec_insn,
false,
)
}

Expand Down Expand Up @@ -239,35 +285,75 @@ impl<F: PrimeField32> InstructionProcessor for InstructionTranspiler<F> {
}

fn process_mul(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(MulOpcode::MUL.global_opcode().as_usize(), 0, &dec_insn)
from_r_type(
MulOpcode::MUL.global_opcode().as_usize(),
0,
&dec_insn,
false,
)
}

fn process_mulh(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(MulHOpcode::MULH.global_opcode().as_usize(), 0, &dec_insn)
from_r_type(
MulHOpcode::MULH.global_opcode().as_usize(),
0,
&dec_insn,
false,
)
}

fn process_mulhu(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(MulHOpcode::MULHU.global_opcode().as_usize(), 0, &dec_insn)
from_r_type(
MulHOpcode::MULHU.global_opcode().as_usize(),
0,
&dec_insn,
false,
)
}

fn process_mulhsu(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(MulHOpcode::MULHSU.global_opcode().as_usize(), 0, &dec_insn)
from_r_type(
MulHOpcode::MULHSU.global_opcode().as_usize(),
0,
&dec_insn,
false,
)
}

fn process_div(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(DivRemOpcode::DIV.global_opcode().as_usize(), 0, &dec_insn)
from_r_type(
DivRemOpcode::DIV.global_opcode().as_usize(),
0,
&dec_insn,
false,
)
}

fn process_divu(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(DivRemOpcode::DIVU.global_opcode().as_usize(), 0, &dec_insn)
from_r_type(
DivRemOpcode::DIVU.global_opcode().as_usize(),
0,
&dec_insn,
false,
)
}

fn process_rem(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(DivRemOpcode::REM.global_opcode().as_usize(), 0, &dec_insn)
from_r_type(
DivRemOpcode::REM.global_opcode().as_usize(),
0,
&dec_insn,
false,
)
}

fn process_remu(&mut self, dec_insn: RType) -> Self::InstructionResult {
from_r_type(DivRemOpcode::REMU.global_opcode().as_usize(), 0, &dec_insn)
from_r_type(
DivRemOpcode::REMU.global_opcode().as_usize(),
0,
&dec_insn,
false,
)
}

fn process_fence(&mut self, dec_insn: IType) -> Self::InstructionResult {
Expand Down
1 change: 1 addition & 0 deletions extensions/sha256/transpiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ impl<F: PrimeField32> TranspilerExtension<F> for Sha256TranspilerExtension {
Rv32Sha256Opcode::SHA256.global_opcode().as_usize(),
RV32_MEMORY_AS as usize,
&dec_insn,
true,
);
Some(TranspilerOutput::one_to_one(instruction))
}
Expand Down

0 comments on commit dc9b8ae

Please sign in to comment.