Skip to content

Commit

Permalink
Add RISC-V support for cycle and clock counters (nexus-xyz#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
duc-nx authored Aug 6, 2024
1 parent 3644996 commit fdededf
Show file tree
Hide file tree
Showing 18 changed files with 370 additions and 69 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- run: rustup target add riscv32i-unknown-none-elf
- run: assets/scripts/smoke.sh examples/src/bin/fib3.rs
- run: assets/scripts/smoke.sh examples/src/bin/fib3_profiling.rs
- run: assets/scripts/smoke.sh examples/src/bin/hello.rs

test-sdk:
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ Cargo.lock
nexus-proof

# macos
.DS_Store
.DS_Store
4 changes: 2 additions & 2 deletions assets/scripts/smoke.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ cd "$PROJECT_NAME"
# Link the test program to the latest runtime code
sed -e "s#git = \"https://github.com/nexus-xyz/nexus-zkvm.git\"#path = \"$ORIGINAL_DIR/runtime\"#" Cargo.toml > Cargo.tmp && mv Cargo.tmp Cargo.toml

"$ORIGINAL_DIR/target/release/cargo-nexus" nexus run
"$ORIGINAL_DIR/target/release/cargo-nexus" nexus run -v
"$ORIGINAL_DIR/target/release/cargo-nexus" nexus prove
"$ORIGINAL_DIR/target/release/cargo-nexus" nexus verify

cleanup
cleanup
2 changes: 2 additions & 0 deletions cli/src/command/new.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ fn setup_crate(path: PathBuf) -> anyhow::Result<()> {
"# Generated by cargo-nexus, do not remove!\n",
"#\n",
"# This profile is used for generating proofs, as Nexus VM support for compiler optimizations is still under development.\n",
"[features]\n",
"cycles = [] # Enable cycle counting for run command\n",
)
)?;

Expand Down
62 changes: 47 additions & 15 deletions cli/src/command/run.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::path::Path;

use clap::Args;
use std::collections::HashSet;
use std::path::Path;

use crate::utils::{cargo, path_to_artifact};

const ALLOWED_FEATURES: [&str; 1] = ["cycles"];

#[derive(Debug, Args)]
pub struct RunArgs {
/// Print instruction trace.
Expand All @@ -17,25 +19,55 @@ pub struct RunArgs {
/// Name of the bin target to run.
#[arg(long)]
pub bin: Option<String>,

/// Build artifacts with the specific features. "cycles" is default.
#[arg(
long,
default_value = "cycles",
value_name = "FEATURES",
use_value_delimiter = true
)]
pub features: Vec<String>,
}

pub fn handle_command(args: RunArgs) -> anyhow::Result<()> {
let RunArgs { verbose, profile, bin } = args;
let RunArgs { verbose, profile, bin, features } = args;

run_vm(bin, verbose, &profile)
run_vm(bin, verbose, &profile, features)
}

fn run_vm(bin: Option<String>, verbose: bool, profile: &str) -> anyhow::Result<()> {
// build artifact
cargo(
None,
[
"build",
"--target=riscv32i-unknown-none-elf",
"--profile",
profile,
],
)?;
fn run_vm(
bin: Option<String>,
verbose: bool,
profile: &str,
features: Vec<String>,
) -> anyhow::Result<()> {
let allowed_features: HashSet<_> = ALLOWED_FEATURES.iter().cloned().collect();

// Build cargo arguments
let mut cargo_args = vec![
"build",
"--target=riscv32i-unknown-none-elf",
"--profile",
profile,
];

// Filter and add valid features
let valid_features: Vec<&&str> = features
.iter()
.filter_map(|f| allowed_features.get(f.as_str()))
.collect();

if !valid_features.is_empty() {
cargo_args.push("--features");
for feature in valid_features {
cargo_args.push(feature);
cargo_args.push(",");
}
cargo_args.pop(); // Remove trailing comma
}

cargo(None, cargo_args)?;

let path = path_to_artifact(bin, profile)?;

Expand Down
3 changes: 3 additions & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ nexus-rt = { path = "../runtime" }
print_with_newline = { level = "allow", priority = 0 }
needless_range_loop = { level = "allow", priority = 0 }
manual_memcpy = { level = "allow", priority = 0 }

[features]
cycles = []
3 changes: 1 addition & 2 deletions examples/src/bin/fib3.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// Used in the CI as a small example that uses memory store
#![no_std]
#![no_main]
#![cfg_attr(target_arch = "riscv32", no_std, no_main)]

fn fib(n: u32) -> u32 {
match n {
Expand Down
37 changes: 37 additions & 0 deletions examples/src/bin/fib3_profiling.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Used in the CI as a small example that uses memory store
#![cfg_attr(target_arch = "riscv32", no_std, no_main)]

#[nexus_rt::profile]
fn fib(n: u32) -> u32 {
match n {
0 => 0,
1 => 1,
_ => fib(n - 1) + fib(n - 2),
}
}

#[nexus_rt::profile]
fn fib2(n: u32) -> u32 {
if n == 0 {
return 0;
}
if n == 1 {
return 1;
}
let mut a = 0;
let mut b = 1;
let mut result = 0;
for _ in 2..=n {
result = a + b;
a = b;
b = result;
}
result
}

#[nexus_rt::main]
fn main() {
let n = 3;
assert_eq!(fib(n), 2);
assert_eq!(fib2(n), 2);
}
1 change: 1 addition & 0 deletions runtime/macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ proc-macro = true
[dependencies]
quote = "1.0"
proc-macro2 = "1.0"
proc-macro-crate = "3.1.0"
syn = { version = "1.0", features = ["full"] }

[dev-dependencies]
Expand Down
6 changes: 6 additions & 0 deletions runtime/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ use proc_macro::TokenStream;

mod entry;
mod parse_args;
mod profile;

#[proc_macro_attribute]
pub fn main(args: TokenStream, input: TokenStream) -> TokenStream {
entry::main(args.into(), input.into())
.map(Into::into)
.unwrap_or_else(|err| err.into_compile_error().into())
}

#[proc_macro_attribute]
pub fn profile(_attr: TokenStream, item: TokenStream) -> TokenStream {
profile::profile(_attr.into(), item.into()).into()
}
33 changes: 33 additions & 0 deletions runtime/macros/src/profile.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use proc_macro2::TokenStream;
use proc_macro_crate::{crate_name, FoundCrate};
use quote::{format_ident, quote};
use syn::{parse2, ItemFn};

fn get_nexus_rt_ident() -> proc_macro2::Ident {
match crate_name("nexus_rt") {
Ok(FoundCrate::Name(name)) => format_ident!("{}", name),
_ => format_ident!("nexus_rt"),
}
}

pub fn profile(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input: ItemFn = parse2(item).expect("Invalid code block");

let ItemFn {
vis: visibility, sig: signature, block, ..
} = input;

let name: &syn::Ident = &signature.ident;
let nexus_rt = get_nexus_rt_ident();

quote! {
#visibility #signature {
#[cfg(feature = "cycles")]
#nexus_rt::cycle_count_ecall(concat!("^#", file!(), ":", stringify!(#name)));
let result = (|| #block)();
#[cfg(feature = "cycles")]
#nexus_rt::cycle_count_ecall(concat!("$#", file!(), ":", stringify!(#name)));
result
}
}
}
8 changes: 7 additions & 1 deletion runtime/src/ecalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ mod riscv32 {
ecall!(3, b.as_ptr(), b.len(), _out);
}

/// Bench cycles with input is function name
pub fn cycle_count_ecall(s: &str) {
let mut _out: u32;
ecall!(5, s.as_ptr(), s.len(), _out);
}

/// An empty type representing the VM terminal
pub struct NexusLog;

Expand Down Expand Up @@ -121,7 +127,7 @@ pub fn write_output<T: serde::Serialize + ?Sized>(_: &T) {
panic!("output is not available outside of NexusVM")
}

/// Write a slice to the output taoe
/// Write a slice to the output tape
#[cfg(not(target_arch = "riscv32"))]
pub fn write_to_output(_: &[u8]) {
panic!("output is not available outside of NexusVM")
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub use runtime::*;
#[cfg(target_arch = "riscv32")]
mod alloc;

pub use nexus_rt_macros::main;
pub use nexus_rt_macros::{main, profile};

mod ecalls;
pub use ecalls::*;
4 changes: 4 additions & 0 deletions vm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ pub enum NexusVMError {
/// Reached the limit of executed instructions.
#[error("reached maximum number of executed instructions: {0}")]
MaxTraceLengthExceeded(usize),

/// Benchmark labels are invalid.
#[error("Labels are invalid.")]
InvalidProfileLabel,
}

/// Result type for VM functions that can produce errors
Expand Down
75 changes: 73 additions & 2 deletions vm/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use crate::{
error::*,
memory::Memory,
rv32::{parse::*, *},
syscalls::Syscalls,
syscalls::{SyscallCode, Syscalls},
};

use std::collections::HashSet;
use crate::NexusVMError;

use std::collections::{HashMap, HashSet};

use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -39,6 +41,11 @@ pub struct NexusVM<M: Memory> {
///
/// Does include executing UNIMP instruction.
pub max_trace_len: Option<usize>,

/// The cycle count for execution trace.
pub cycle_count: u64,
/// The cycles tracker label: (func_name, (cycle_count, counter))
pub cycle_tracker: HashMap<String, (u64, u32)>,
}

/// ISA defined registers
Expand All @@ -65,6 +72,7 @@ impl<M: Memory> NexusVM<M> {

vm.regs.pc = pc;
vm.instruction_sets = HashSet::new();
vm.cycle_count = 0;

vm
}
Expand Down Expand Up @@ -186,6 +194,48 @@ fn alu_op(aop: AOP, x: u32, y: u32) -> u32 {
}
}

fn handle_profile_cycles(vm: &mut NexusVM<impl Memory>) -> Result<()> {
let label = vm
.syscalls
.get_label()
.and_then(|s| std::str::from_utf8(&s).ok().map(|s| s.to_owned()))
.ok_or(NexusVMError::InvalidProfileLabel)?;

let fn_name = label
.split('#')
.last()
.ok_or(NexusVMError::InvalidProfileLabel)?
.to_owned();

match label.chars().next() {
Some('^') => start_profile(vm, fn_name),
Some('$') => end_profile(vm, fn_name)?,
_ => return Err(NexusVMError::InvalidProfileLabel),
}

Ok(())
}

fn start_profile(vm: &mut NexusVM<impl Memory>, fn_name: String) {
vm.cycle_tracker
.entry(fn_name)
.or_insert((vm.cycle_count, 0))
.1 += 1;
}

fn end_profile(vm: &mut NexusVM<impl Memory>, fn_name: String) -> Result<()> {
let (clk, counter) = vm
.cycle_tracker
.get_mut(&fn_name)
.ok_or(NexusVMError::InvalidProfileLabel)?;

*counter -= 1;
if *counter == 0 {
*clk = vm.cycle_count - *clk;
}
Ok(())
}

/// evaluate next instruction
pub fn eval_inst(vm: &mut NexusVM<impl Memory>) -> Result<()> {
if vm
Expand Down Expand Up @@ -277,17 +327,38 @@ pub fn eval_inst(vm: &mut NexusVM<impl Memory>) -> Result<()> {
ECALL { rd } => {
RD = rd;
vm.Z = vm.syscalls.syscall(vm.regs.pc, vm.regs.x, &vm.mem)?;
// Profile cycles
if vm.regs.x[18] == SyscallCode::ProfileCycles as u32 {
handle_profile_cycles(vm)?;
}
}
UNIMP => {
PC = vm.inst.pc;
}
}

// Counts cycles per instruction kind.
// In RISC-V:
// - Memory instructions are 3 cycles
// - Branch instructions are 2 cycles
// - ALU/ALUI/JAR/LUI/AUIPC instructions are 1 cycle
// - System call instruction are 4 cycles
// - Unknown instruction is 0 cycle
let cycles = match vm.inst.inst {
LOAD { .. } | STORE { .. } => 3,
BR { .. } => 2,
LUI { .. } | AUIPC { .. } | JAL { .. } | JALR { .. } | ALUI { .. } | ALU { .. } => 1,
ECALL { .. } => 4,
EBREAK { .. } => 4,
_ => 0,
};

if PC == 0 {
PC = add32(vm.inst.pc, vm.inst.len);
}
vm.set_reg(RD, vm.Z);
vm.regs.pc = PC;
vm.trace_len += 1;
vm.cycle_count += cycles;
Ok(())
}
Loading

0 comments on commit fdededf

Please sign in to comment.