From 9e7a21465eb0e38d83087b2929842a7e15905a24 Mon Sep 17 00:00:00 2001 From: IlyasRidhuan Date: Thu, 13 Jun 2024 13:29:21 +0000 Subject: [PATCH] feat(avm): msm blackbox --- avm-transpiler/src/opcodes.rs | 2 + avm-transpiler/src/transpile.rs | 23 +++ .../contracts/avm_test_contract/src/main.nr | 10 +- yarn-project/foundation/src/fields/point.ts | 10 +- yarn-project/simulator/src/avm/avm_gas.ts | 1 + .../simulator/src/avm/avm_simulator.test.ts | 12 ++ .../src/avm/opcodes/multi_scalar_mul.test.ts | 142 ++++++++++++++++++ .../src/avm/opcodes/multi_scalar_mul.ts | 116 ++++++++++++++ .../serialization/bytecode_serialization.ts | 2 + .../instruction_serialization.ts | 1 + 10 files changed, 317 insertions(+), 2 deletions(-) create mode 100644 yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts create mode 100644 yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts diff --git a/avm-transpiler/src/opcodes.rs b/avm-transpiler/src/opcodes.rs index e4d642dd8298..11cd956237dc 100644 --- a/avm-transpiler/src/opcodes.rs +++ b/avm-transpiler/src/opcodes.rs @@ -73,6 +73,7 @@ pub enum AvmOpcode { SHA256, // temp - may be removed, but alot of contracts rely on it PEDERSEN, // temp - may be removed, but alot of contracts rely on it ECADD, + MSM, // Conversions TORADIXLE, } @@ -165,6 +166,7 @@ impl AvmOpcode { AvmOpcode::SHA256 => "SHA256 ", AvmOpcode::PEDERSEN => "PEDERSEN", AvmOpcode::ECADD => "ECADD", + AvmOpcode::MSM => "MSM", // Conversions AvmOpcode::TORADIXLE => "TORADIXLE", } diff --git a/avm-transpiler/src/transpile.rs b/avm-transpiler/src/transpile.rs index 67d0f8043a6e..a1ae1b716443 100644 --- a/avm-transpiler/src/transpile.rs +++ b/avm-transpiler/src/transpile.rs @@ -855,6 +855,29 @@ fn handle_black_box_function(avm_instrs: &mut Vec, operation: &B ], ..Default::default() }), + // Temporary while we dont have efficient noir implementations + BlackBoxOp::MultiScalarMul { points, scalars, outputs } => { + // The length of the scalars vector is 2x the length of the points vector due to limb + // decomposition + let points_offset = points.pointer.0; + let num_points = points.size.0; + let scalars_offset = scalars.pointer.0; + // Output array is fixed to 3 + let outputs_offset = outputs.pointer.0; + avm_instrs.push(AvmInstruction { + opcode: AvmOpcode::MSM, + indirect: Some( + ZEROTH_OPERAND_INDIRECT | FIRST_OPERAND_INDIRECT | SECOND_OPERAND_INDIRECT, + ), + operands: vec![ + AvmOperand::U32 { value: points_offset as u32 }, + AvmOperand::U32 { value: scalars_offset as u32 }, + AvmOperand::U32 { value: outputs_offset as u32 }, + AvmOperand::U32 { value: num_points as u32 }, + ], + ..Default::default() + }); + } _ => panic!("Transpiler doesn't know how to process {:?}", operation), } } diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index c6271fc42ada..c3d827652f1d 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -24,7 +24,7 @@ contract AvmTest { global big_field_136_bits: Field = 0x991234567890abcdef1234567890abcdef; // Libs - use dep::std::embedded_curve_ops::EmbeddedCurvePoint; + use dep::std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul}; use dep::aztec::protocol_types::constants::CONTRACT_INSTANCE_LENGTH; use dep::aztec::prelude::{Map, Deserialize}; use dep::aztec::state_vars::PublicMutable; @@ -144,6 +144,14 @@ contract AvmTest { added } + #[aztec(public)] + fn variable_base_msm() -> [Field; 3] { + let g = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; + let scalar = EmbeddedCurveScalar { lo: 3, hi: 0 }; + let triple_g = multi_scalar_mul([g], [scalar]); + triple_g + } + /************************************************************************ * Misc ************************************************************************/ diff --git a/yarn-project/foundation/src/fields/point.ts b/yarn-project/foundation/src/fields/point.ts index b152bffcc6a2..490c1fe2f48e 100644 --- a/yarn-project/foundation/src/fields/point.ts +++ b/yarn-project/foundation/src/fields/point.ts @@ -137,8 +137,16 @@ export class Point { return poseidon2Hash(this.toFields()); } + /** + * Check if this is point at infinity. + */ + isInfPoint() { + // Check this + return this.x.isZero(); + } + isOnGrumpkin() { - if (this.isZero()) { + if (this.isInfPoint()) { return true; } diff --git a/yarn-project/simulator/src/avm/avm_gas.ts b/yarn-project/simulator/src/avm/avm_gas.ts index 7802d4177d1e..d951f20fa32b 100644 --- a/yarn-project/simulator/src/avm/avm_gas.ts +++ b/yarn-project/simulator/src/avm/avm_gas.ts @@ -123,6 +123,7 @@ const BaseGasCosts: Record = { [Opcode.SHA256]: DefaultBaseGasCost, [Opcode.PEDERSEN]: DefaultBaseGasCost, [Opcode.ECADD]: DefaultBaseGasCost, + [Opcode.MSM]: DefaultBaseGasCost, // Conversions [Opcode.TORADIXLE]: DefaultBaseGasCost, }; diff --git a/yarn-project/simulator/src/avm/avm_simulator.test.ts b/yarn-project/simulator/src/avm/avm_simulator.test.ts index f1e8c98fb904..86d1960577b4 100644 --- a/yarn-project/simulator/src/avm/avm_simulator.test.ts +++ b/yarn-project/simulator/src/avm/avm_simulator.test.ts @@ -108,6 +108,18 @@ describe('AVM simulator: transpiled Noir contracts', () => { expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]); }); + it('variable msm operations', async () => { + const context = initContext(); + + const bytecode = getAvmTestContractBytecode('variable_base_msm'); + const results = await new AvmSimulator(context).executeBytecode(bytecode); + + expect(results.reverted).toBe(false); + const grumpkin = new Grumpkin(); + const g3 = grumpkin.mul(grumpkin.generator(), new Fq(3)); + expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]); + }); + describe('U128 addition and overflows', () => { it('U128 addition', async () => { const calldata: Fr[] = [ diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts new file mode 100644 index 000000000000..83a9b79ca311 --- /dev/null +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts @@ -0,0 +1,142 @@ +import { Fq, Fr } from '@aztec/circuits.js'; +import { Grumpkin } from '@aztec/circuits.js/barretenberg'; + +import { type AvmContext } from '../avm_context.js'; +import { Field, Uint8, Uint32 } from '../avm_memory_types.js'; +import { initContext } from '../fixtures/index.js'; +import { MultiScalarMul } from './multi_scalar_mul.js'; + +describe('MultiScalarMul Opcode', () => { + let context: AvmContext; + + beforeEach(async () => { + context = initContext(); + }); + it('Should (de)serialize correctly', () => { + const buf = Buffer.from([ + MultiScalarMul.opcode, // opcode + 7, // indirect + ...Buffer.from('12345678', 'hex'), // pointsOffset + ...Buffer.from('23456789', 'hex'), // scalars Offset + ...Buffer.from('3456789a', 'hex'), // outputOffset + ...Buffer.from('456789ab', 'hex'), // pointsLengthOffset + ]); + const inst = new MultiScalarMul( + /*indirect=*/ 7, + /*pointsOffset=*/ 0x12345678, + /*scalarsOffset=*/ 0x23456789, + /*outputOffset=*/ 0x3456789a, + /*pointsLengthOffset=*/ 0x456789ab, + ); + + expect(MultiScalarMul.deserialize(buf)).toEqual(inst); + expect(inst.serialize()).toEqual(buf); + }); + + it('Should perform msm correctly - direct', async () => { + const indirect = 0; + const grumpkin = new Grumpkin(); + // We need to ensure points are actually on curve, so we just use the generator + // In future we could use a random point, for now we create an array of [G, 2G, 3G] + const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); + + // Pick some big scalars to test the edge cases + const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; + const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory + const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory + // Transform the points and scalars into the format that we will write to memory + // We just store the x and y coordinates here, and handle the infinities when we write to memory + const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); + const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); + + const pointsOffset = 0; + // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) + // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] + for (let i = 0; i < points.length; i++) { + const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y + const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf + context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); + context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); + context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); + } + // Store scalars + const scalarsOffset = pointsOffset + pointsReadLength; + context.machineState.memory.setSlice(scalarsOffset, storedScalars); + // Store length of points to read + const pointsLengthOffset = scalarsOffset + scalarsLength; + context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); + const outputOffset = pointsLengthOffset + 1; + + await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context); + + const result = context.machineState.memory.getSlice(outputOffset, 3); + + // We write it out explicitly here + let expectedResult = grumpkin.mul(points[0], scalars[0]); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); + + expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + }); + + it('Should perform msm correctly - indirect', async () => { + const indirect = 7; + const grumpkin = new Grumpkin(); + // We need to ensure points are actually on curve, so we just use the generator + // In future we could use a random point, for now we create an array of [G, 2G, 3G] + const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); + + // Pick some big scalars to test the edge cases + const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; + const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory + const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory + // Transform the points and scalars into the format that we will write to memory + // We just store the x and y coordinates here, and handle the infinities when we write to memory + const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); + const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); + + const pointsOffset = 0; + // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) + // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] + for (let i = 0; i < points.length; i++) { + const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y + const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf + context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); + context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); + context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); + } + // Store scalars + const scalarsOffset = pointsOffset + pointsReadLength; + context.machineState.memory.setSlice(scalarsOffset, storedScalars); + // Store length of points to read + const pointsLengthOffset = scalarsOffset + scalarsLength; + context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); + const outputOffset = pointsLengthOffset + 1; + + // Set up the indirect pointers + const pointsIndirectOffset = outputOffset + 3; /* 3 since the output is a triplet */ + const scalarsIndirectOffset = pointsIndirectOffset + 1; + const outputIndirectOffset = scalarsIndirectOffset + 1; + + context.machineState.memory.set(pointsIndirectOffset, new Uint32(pointsOffset)); + context.machineState.memory.set(scalarsIndirectOffset, new Uint32(scalarsOffset)); + context.machineState.memory.set(outputIndirectOffset, new Uint32(outputOffset)); + + await new MultiScalarMul( + indirect, + pointsIndirectOffset, + scalarsIndirectOffset, + outputIndirectOffset, + pointsLengthOffset, + ).execute(context); + + const result = context.machineState.memory.getSlice(outputOffset, 3); + + // We write it out explicitly here + let expectedResult = grumpkin.mul(points[0], scalars[0]); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); + + expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + }); +}); diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts new file mode 100644 index 000000000000..70a370b231b2 --- /dev/null +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts @@ -0,0 +1,116 @@ +import { Fq, Fr, Point } from '@aztec/circuits.js'; +import { Grumpkin } from '@aztec/circuits.js/barretenberg'; + +import { type AvmContext } from '../avm_context.js'; +import { Field, TypeTag } from '../avm_memory_types.js'; +import { InstructionExecutionError } from '../errors.js'; +import { Opcode, OperandType } from '../serialization/instruction_serialization.js'; +import { Addressing } from './addressing_mode.js'; +import { Instruction } from './instruction.js'; + +export class MultiScalarMul extends Instruction { + static type: string = 'MultiScalarMul'; + static readonly opcode: Opcode = Opcode.MSM; + + // Informs (de)serialization. See Instruction.deserialize. + static readonly wireFormat: OperandType[] = [ + OperandType.UINT8 /* opcode */, + OperandType.UINT8 /* indirect */, + OperandType.UINT32 /* points vector offset */, + OperandType.UINT32 /* scalars vector offset */, + OperandType.UINT32 /* output offset (fixed triplet)*/, + OperandType.UINT32 /* points length offset */, + ]; + + constructor( + private indirect: number, + private pointsOffset: number, + private scalarsOffset: number, + private outputOffset: number, + private pointsLengthOffset: number, + ) { + super(); + } + + public async execute(context: AvmContext): Promise { + const memory = context.machineState.memory.track(this.type); + // Resolve indirects + const [pointsOffset, scalarsOffset, outputOffset] = Addressing.fromWire(this.indirect).resolve( + [this.pointsOffset, this.scalarsOffset, this.outputOffset], + memory, + ); + + // Length of the points vector should be U32 + memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset); + + // Get the size of the unrolled (x, y , inf) points vector + // TODO: Do we need to assert that the length is a multiple of 3 (x, y, inf)? + const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber(); + // Divide by 3 since each point is represented as a triplet to get the number of points + const numPoints = pointsReadLength / 3; + // The tag for each triplet will be (Field, Field, Uint8) + for (let i = 0; i < numPoints; i++) { + const offset = pointsOffset + i * 3; + // Check (Field, Field) + memory.checkTagsRange(TypeTag.FIELD, offset, 2); + // Check Uint8 (inf flag) + memory.checkTag(TypeTag.UINT8, offset + 2); + } + // Get the unrolled (x, y, inf) representing the points + const pointsVector = memory.getSlice(pointsOffset, pointsReadLength); + + // The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition + const scalarReadLength = numPoints * 2; + // Get the unrolled scalar (lo & hi) representing the scalars + const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength); + memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength); + + // Now we need to reconstruct the points and scalars into something we can operate on. + const grumpkinPoints: Point[] = []; + for (let i = 0; i < numPoints; i++) { + const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr()); + // Include this later when we have a standard for representing infinity + // const isInf = pointsVector[i + 2].toBoolean(); + + if (!p.isOnGrumpkin()) { + throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`); + } + grumpkinPoints.push(p); + } + // The scalars are read from memory as Fr elements, which are limbs of Fq elements + // So we need to reconstruct them before performing the scalar multiplications + const scalarFqVector: Fq[] = []; + for (let i = 0; i < numPoints; i++) { + const scalarLo = scalarsVector[2 * i].toFr(); + const scalarHi = scalarsVector[2 * i + 1].toFr(); + const fqScalar = Fq.fromHighLow(scalarHi, scalarLo); + scalarFqVector.push(fqScalar); + } + // TODO: Is there an efficient MSM implementation in ts that we can replace this by? + const grumpkin = new Grumpkin(); + // Zip the points and scalars into pairs + const [firstBaseScalarPair, ...rest]: Array<[Point, Fq]> = grumpkinPoints.map((p, idx) => [p, scalarFqVector[idx]]); + // Fold the points and scalars into a single point + // We have to ensure get the first point, since the identity element (point at infinity) isn't quite working in ts + const outputPoint = rest.reduce( + (acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])), + grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]), + ); + // TODO: Check the Infinity flag here + const output: Fr[] = [outputPoint.x, outputPoint.y, outputPoint.isInfPoint() ? Fr.ONE : Fr.ZERO]; + + memory.setSlice( + outputOffset, + output.map(word => new Field(word)), + ); + + const memoryOperations = { + reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, + writes: 3 /* output triplet */, + indirect: this.indirect, + }; + context.machineState.consumeGas(this.gasCost(memoryOperations)); + memory.assert(memoryOperations); + context.machineState.incrementPc(); + } +} diff --git a/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts b/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts index 0cf22ba0a5c4..f3afe05e0888 100644 --- a/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts +++ b/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts @@ -53,6 +53,7 @@ import { Version, Xor, } from '../opcodes/index.js'; +import { MultiScalarMul } from '../opcodes/multi_scalar_mul.js'; import { BufferCursor } from './buffer_cursor.js'; import { Opcode } from './instruction_serialization.js'; @@ -143,6 +144,7 @@ const INSTRUCTION_SET = () => [Poseidon2.opcode, Poseidon2], [Sha256.opcode, Sha256], [Pedersen.opcode, Pedersen], + [MultiScalarMul.opcode, MultiScalarMul], // Conversions [ToRadixLE.opcode, ToRadixLE], ]); diff --git a/yarn-project/simulator/src/avm/serialization/instruction_serialization.ts b/yarn-project/simulator/src/avm/serialization/instruction_serialization.ts index d8ccbd918409..0a4ee888fcf2 100644 --- a/yarn-project/simulator/src/avm/serialization/instruction_serialization.ts +++ b/yarn-project/simulator/src/avm/serialization/instruction_serialization.ts @@ -77,6 +77,7 @@ export enum Opcode { SHA256, // temp - may be removed, but alot of contracts rely on it PEDERSEN, // temp - may be removed, but alot of contracts rely on it ECADD, + MSM, // Conversion TORADIXLE, }