Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(katana): new command line interface (breaking) #2663

Merged
merged 15 commits into from
Nov 10, 2024
350 changes: 215 additions & 135 deletions bin/katana/src/cli/node.rs

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions bin/katana/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::fmt::Display;
use std::path::PathBuf;

use anyhow::{Context, Result};
use clap::builder::PossibleValue;
use clap::ValueEnum;
use katana_primitives::block::{BlockHash, BlockHashOrNumber, BlockNumber};
use katana_primitives::genesis::json::GenesisJson;
use katana_primitives::genesis::Genesis;
Expand Down Expand Up @@ -34,6 +37,34 @@ pub fn parse_block_hash_or_number(value: &str) -> Result<BlockHashOrNumber> {
}
}

#[derive(Debug, Clone, PartialEq)]
pub enum LogFormat {
Json,
Full,
}

impl ValueEnum for LogFormat {
fn value_variants<'a>() -> &'a [Self] {
&[Self::Json, Self::Full]
}

fn to_possible_value(&self) -> Option<PossibleValue> {
match self {
Self::Json => Some(PossibleValue::new("json")),
Self::Full => Some(PossibleValue::new("full")),
}
}
}

impl Display for LogFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Json => write!(f, "json"),
Self::Full => write!(f, "full"),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion crates/dojo/test-utils/src/sequencer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pub fn get_default_test_config(sequencing: SequencingConfig) -> Config {
chain.genesis.sequencer_address = *DEFAULT_SEQUENCER_ADDRESS;

let rpc = RpcConfig {
allowed_origins: None,
cors_domain: None,
port: 0,
addr: DEFAULT_RPC_ADDR,
max_connections: DEFAULT_RPC_MAX_CONNECTIONS,
Expand Down
139 changes: 94 additions & 45 deletions crates/katana/node-bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,20 +176,21 @@ pub struct Katana {
json_log: bool,
block_time: Option<u64>,
db_dir: Option<PathBuf>,
rpc_url: Option<String>,
l1_provider: Option<String>,
fork_block_number: Option<u64>,
messaging: Option<PathBuf>,

// Metrics options
metrics: Option<String>,
metrics_addr: Option<SocketAddr>,
metrics_port: Option<u16>,

// Server options
port: Option<u16>,
host: Option<String>,
max_connections: Option<u64>,
allowed_origins: Option<String>,
http_addr: Option<SocketAddr>,
http_port: Option<u16>,
rpc_max_connections: Option<u64>,
http_cors_domain: Option<String>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
http_cors_domain: Option<String>,
http_cors_domains: Option<String>,

can't the input contains several values separated by commas? If yes, wdyt about keeping it plural?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would go with cors_origins actually, wdyt?
image

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah definitely remove the http_ so that it's not confused with the http listen address/port

Copy link
Member Author

@kariy kariy Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least on the CLI arg (ie --http.corsdomain), I prefixed it with http to indicate that the cors domains are only applied to the http server as we may introduce different transports (eg ws) and the idea is to allow configuring cors for each server.

Im a bit premature optimizing for this because its on the cli side to make sure we dont introduce these sorta changes too much. But in the code snippet above, i dont mind renaming it to what you guys suggested. Just want to make sure people are aware of my intention.


// Starknet options
// Dev options
seed: Option<u64>,
accounts: Option<u16>,
disable_fee: bool,
Expand Down Expand Up @@ -253,7 +254,7 @@ impl Katana {

/// Sets the port which will be used when the `katana` instance is launched.
pub fn port<T: Into<u16>>(mut self, port: T) -> Self {
self.port = Some(port.into());
self.http_port = Some(port.into());
self
}

Expand All @@ -271,8 +272,8 @@ impl Katana {
}

/// Sets the RPC URL to fork the network from.
pub fn rpc_url<T: Into<String>>(mut self, rpc_url: T) -> Self {
self.rpc_url = Some(rpc_url.into());
pub fn l1_provider<T: Into<String>>(mut self, rpc_url: T) -> Self {
self.l1_provider = Some(rpc_url.into());
self
}

Expand Down Expand Up @@ -301,27 +302,33 @@ impl Katana {
self
}

/// Enables Prometheus metrics and sets the socket address.
pub fn metrics<T: Into<String>>(mut self, metrics: T) -> Self {
self.metrics = Some(metrics.into());
/// Enables Prometheus metrics and sets the metrics server address.
pub fn metrics_addr<T: Into<SocketAddr>>(mut self, addr: T) -> Self {
self.metrics_addr = Some(addr.into());
self
}

/// Enables Prometheus metrics and sets the metrics server port.
pub fn metrics_port<T: Into<u16>>(mut self, port: T) -> Self {
self.metrics_port = Some(port.into());
self
}

/// Sets the host IP address the server will listen on.
pub fn host<T: Into<String>>(mut self, host: T) -> Self {
self.host = Some(host.into());
pub fn http_addr<T: Into<SocketAddr>>(mut self, addr: T) -> Self {
self.http_addr = Some(addr.into());
self
}

/// Sets the maximum number of concurrent connections allowed.
pub const fn max_connections(mut self, max_connections: u64) -> Self {
self.max_connections = Some(max_connections);
pub const fn rpc_max_connections(mut self, max_connections: u64) -> Self {
self.rpc_max_connections = Some(max_connections);
self
}

/// Enables the CORS layer and sets the allowed origins, separated by commas.
pub fn allowed_origins<T: Into<String>>(mut self, allowed_origins: T) -> Self {
self.allowed_origins = Some(allowed_origins.into());
pub fn http_cors_domain<T: Into<String>>(mut self, allowed_origins: T) -> Self {
self.http_cors_domain = Some(allowed_origins.into());
self
}

Expand Down Expand Up @@ -414,8 +421,14 @@ impl Katana {
let mut cmd = self.program.as_ref().map_or_else(|| Command::new("katana"), Command::new);
cmd.stdout(std::process::Stdio::piped()).stderr(std::process::Stdio::inherit());

let mut port = self.port.unwrap_or(0);
cmd.arg("--port").arg(port.to_string());
if let Some(host) = self.http_addr {
cmd.arg("--http.addr").arg(host.to_string());
}

// In the case where port 0 is set, we will need to extract the actual port number
// from the logs.
let mut port = self.http_port.unwrap_or(5050);
cmd.arg("--http.port").arg(port.to_string());

if self.no_mining {
cmd.arg("--no-mining");
Expand All @@ -428,56 +441,92 @@ impl Katana {
if let Some(db_dir) = self.db_dir {
cmd.arg("--db-dir").arg(db_dir);
}
if let Some(rpc_url) = self.rpc_url {
cmd.arg("--rpc-url").arg(rpc_url);

if let Some(url) = self.l1_provider {
cmd.arg("--l1.provider").arg(url);
}

// Need to make sure that the `--dev` is not being set twice.
let mut is_dev = false;

if self.dev {
cmd.arg("--dev");
is_dev = true;
}

if self.json_log {
cmd.arg("--json-log");
if let Some(seed) = self.seed {
if !is_dev {
cmd.arg("--dev");
is_dev = true;
}

cmd.arg("--dev.seed").arg(seed.to_string());
}

if let Some(fork_block_number) = self.fork_block_number {
cmd.arg("--fork-block-number").arg(fork_block_number.to_string());
if let Some(accounts) = self.accounts {
if !is_dev {
cmd.arg("--dev");
is_dev = true;
}

cmd.arg("--dev.accounts").arg(accounts.to_string());
}

if let Some(messaging) = self.messaging {
cmd.arg("--messaging").arg(messaging);
if self.disable_fee {
if !is_dev {
cmd.arg("--dev");
is_dev = true;
}

cmd.arg("--dev.no-fee");
}

if let Some(metrics) = self.metrics {
cmd.arg("--metrics").arg(metrics);
if self.disable_validate {
if !is_dev {
cmd.arg("--dev");
}

cmd.arg("--dev.no-account-validation");
Comment on lines +484 to +489
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure consistent setting of is_dev flag when adding --dev

Ohayo, sensei! In the disable_validate block, after adding --dev, the is_dev flag isn't set to true, unlike in other similar blocks. This could lead to --dev being added multiple times if multiple dev options are enabled. For consistency and to avoid potential issues, consider setting is_dev = true after adding --dev.

Apply this diff to fix the inconsistency:

            if self.disable_validate {
                if !is_dev {
                    cmd.arg("--dev");
+                   is_dev = true;
                }

                cmd.arg("--dev.no-account-validation");
            }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if self.disable_validate {
if !is_dev {
cmd.arg("--dev");
}
cmd.arg("--dev.no-account-validation");
if self.disable_validate {
if !is_dev {
cmd.arg("--dev");
is_dev = true;
}
cmd.arg("--dev.no-account-validation");

}

if let Some(host) = self.host {
cmd.arg("--host").arg(host);
if self.json_log {
cmd.args(["--log.format", "json"]);
}

if let Some(max_connections) = self.max_connections {
cmd.arg("--max-connections").arg(max_connections.to_string());
if let Some(fork_block_number) = self.fork_block_number {
cmd.args(["--fork", "--fork.block"]).arg(fork_block_number.to_string());
}

if let Some(allowed_origins) = self.allowed_origins {
cmd.arg("--allowed-origins").arg(allowed_origins);
if let Some(messaging) = self.messaging {
cmd.arg("--messaging").arg(messaging);
}

if let Some(seed) = self.seed {
cmd.arg("--seed").arg(seed.to_string());
// Need to make sure that the `--metrics` is not being set twice.
let mut metrics_enabled = false;

if let Some(addr) = self.metrics_addr {
if !metrics_enabled {
cmd.arg("--metrics");
metrics_enabled = true;
}

cmd.arg("--metrics.addr").arg(addr.to_string());
Comment on lines +507 to +513
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid potential multiple --metrics flags

Ohayo, sensei! In the metrics_addr and metrics_port blocks, you're checking if !metrics_enabled before adding --metrics, but you don't set metrics_enabled = true after adding --metrics in the metrics_port block. This could lead to --metrics being added multiple times. For consistency, consider setting metrics_enabled = true after adding --metrics.

Apply this diff to fix the inconsistency:

            if let Some(port) = self.metrics_port {
                if !metrics_enabled {
                    cmd.arg("--metrics");
+                   metrics_enabled = true;
                }

                cmd.arg("--metrics.port").arg(port.to_string());
            }

Committable suggestion skipped: line range outside the PR's diff.

}

if let Some(accounts) = self.accounts {
cmd.arg("--accounts").arg(accounts.to_string());
if let Some(port) = self.metrics_port {
if !metrics_enabled {
cmd.arg("--metrics");
}

cmd.arg("--metrics.port").arg(port.to_string());
Comment on lines +516 to +521
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix potential duplicate metrics flag.

Ohayo, sensei! The metrics_port block doesn't set metrics_enabled = true after adding --metrics, which could lead to duplicate flags.

Apply this diff to fix the inconsistency:

        if let Some(port) = self.metrics_port {
            if !metrics_enabled {
                cmd.arg("--metrics");
+               metrics_enabled = true;
            }

            cmd.arg("--metrics.port").arg(port.to_string());
        }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if let Some(port) = self.metrics_port {
if !metrics_enabled {
cmd.arg("--metrics");
}
cmd.arg("--metrics.port").arg(port.to_string());
if let Some(port) = self.metrics_port {
if !metrics_enabled {
cmd.arg("--metrics");
metrics_enabled = true;
}
cmd.arg("--metrics.port").arg(port.to_string());

}

if self.disable_fee {
cmd.arg("--disable-fee");
if let Some(max_connections) = self.rpc_max_connections {
cmd.arg("--rpc.max-connections").arg(max_connections.to_string());
}

if self.disable_validate {
cmd.arg("--disable-validate");
if let Some(allowed_origins) = self.http_cors_domain {
cmd.arg("--http.corsdomain").arg(allowed_origins);
}

if let Some(chain_id) = self.chain_id {
Expand Down
18 changes: 16 additions & 2 deletions crates/katana/node/src/config/metrics.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

/// Metrics server default address.
pub const DEFAULT_METRICS_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
/// Metrics server default port.
pub const DEFAULT_METRICS_PORT: u16 = 9100;

/// Node metrics configurations.
#[derive(Debug, Clone)]
pub struct MetricsConfig {
/// The address to bind the metrics server to.
pub addr: SocketAddr,
pub addr: IpAddr,
/// The port to bind the metrics server to.
pub port: u16,
}

impl MetricsConfig {
/// Returns the [`SocketAddr`] for the metrics server.
pub fn socket_addr(&self) -> SocketAddr {
SocketAddr::new(self.addr, self.port)
}
}
4 changes: 2 additions & 2 deletions crates/katana/node/src/config/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ pub struct RpcConfig {
pub addr: IpAddr,
pub port: u16,
pub max_connections: u32,
pub allowed_origins: Option<Vec<String>>,
pub apis: HashSet<ApiKind>,
pub cors_domain: Option<Vec<String>>,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Ohayo sensei! The field renaming needs to be synchronized across the codebase

The allowed_origins field is still being used in multiple locations:

  • bin/torii/src/main.rs
  • crates/torii/server/src/proxy.rs
  • crates/katana/node-bindings/src/lib.rs

This inconsistency could lead to confusion and potential bugs. The renaming should be applied consistently across all related components.

🔗 Analysis chain

Ohayo sensei! The field renaming improves semantic clarity.

The change from allowed_origins to cors_domain better reflects the field's purpose in handling Cross-Origin Resource Sharing (CORS) configurations.

Let's verify the impact of this rename:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any remaining references to allowed_origins
# to ensure complete renaming across the codebase

echo "Checking for any remaining references to allowed_origins..."
rg "allowed_origins" 

Length of output: 1405

}

impl RpcConfig {
Expand All @@ -37,7 +37,7 @@ impl RpcConfig {
impl Default for RpcConfig {
fn default() -> Self {
Self {
allowed_origins: None,
cors_domain: None,
addr: DEFAULT_RPC_ADDR,
port: DEFAULT_RPC_PORT,
max_connections: DEFAULT_RPC_MAX_CONNECTIONS,
Expand Down
32 changes: 16 additions & 16 deletions crates/katana/node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ impl Node {

// TODO: maybe move this to the build stage
if let Some(ref cfg) = self.metrics_config {
let addr = cfg.socket_addr();
let mut reports: Vec<Box<dyn Report>> = Vec::new();

if let Some(ref db) = self.db {
Expand All @@ -116,8 +117,8 @@ impl Node {
let exporter = PrometheusRecorder::current().expect("qed; should exist at this point");
let server = MetricsServer::new(exporter).with_process_metrics().with_reports(reports);

self.task_manager.task_spawner().build_task().spawn(server.start(cfg.addr));
info!(addr = %cfg.addr, "Metrics server started.");
self.task_manager.task_spawner().build_task().spawn(server.start(addr));
info!(%addr, "Metrics server started.");
}

let pool = self.pool.clone();
Expand Down Expand Up @@ -312,20 +313,19 @@ pub async fn spawn<EF: ExecutorFactory>(
.allow_methods([Method::POST, Method::GET])
.allow_headers([hyper::header::CONTENT_TYPE, "argent-client".parse().unwrap(), "argent-version".parse().unwrap()]);

let cors =
config.allowed_origins.clone().map(|allowed_origins| match allowed_origins.as_slice() {
[origin] if origin == "*" => cors.allow_origin(AllowOrigin::mirror_request()),
origins => cors.allow_origin(
origins
.iter()
.map(|o| {
let _ = o.parse::<Uri>().expect("Invalid URI");

o.parse().expect("Invalid origin")
})
.collect::<Vec<_>>(),
),
});
let cors = config.cors_domain.clone().map(|allowed_origins| match allowed_origins.as_slice() {
[origin] if origin == "*" => cors.allow_origin(AllowOrigin::mirror_request()),
origins => cors.allow_origin(
origins
.iter()
.map(|o| {
let _ = o.parse::<Uri>().expect("Invalid URI");

o.parse().expect("Invalid origin")
})
.collect::<Vec<_>>(),
),
});
Comment on lines +316 to +328
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ohayo! Let's improve error handling in CORS configuration, sensei!

The current implementation has a few concerns:

  1. Double parsing: URI is parsed twice (lines 322 and 324)
  2. Use of expect() could cause panics in production
  3. Missing proper error handling for invalid URIs

Here's a suggested improvement:

 let cors = config.cors_domain.clone().map(|allowed_origins| match allowed_origins.as_slice() {
     [origin] if origin == "*" => cors.allow_origin(AllowOrigin::mirror_request()),
     origins => cors.allow_origin(
         origins
             .iter()
-            .map(|o| {
-                let _ = o.parse::<Uri>().expect("Invalid URI");
-
-                o.parse().expect("Invalid origin")
-            })
+            .filter_map(|o| {
+                match o.parse::<Uri>() {
+                    Ok(_) => o.parse().ok(),
+                    Err(e) => {
+                        tracing::warn!("Invalid URI in CORS configuration: {}", e);
+                        None
+                    }
+                }
+            })
             .collect::<Vec<_>>(),
     ),
 });

This change:

  1. Eliminates double parsing
  2. Replaces expect() with proper error handling
  3. Logs invalid URIs instead of panicking
  4. Filters out invalid origins gracefully
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
let cors = config.cors_domain.clone().map(|allowed_origins| match allowed_origins.as_slice() {
[origin] if origin == "*" => cors.allow_origin(AllowOrigin::mirror_request()),
origins => cors.allow_origin(
origins
.iter()
.map(|o| {
let _ = o.parse::<Uri>().expect("Invalid URI");
o.parse().expect("Invalid origin")
})
.collect::<Vec<_>>(),
),
});
let cors = config.cors_domain.clone().map(|allowed_origins| match allowed_origins.as_slice() {
[origin] if origin == "*" => cors.allow_origin(AllowOrigin::mirror_request()),
origins => cors.allow_origin(
origins
.iter()
.filter_map(|o| {
match o.parse::<Uri>() {
Ok(_) => o.parse().ok(),
Err(e) => {
tracing::warn!("Invalid URI in CORS configuration: {}", e);
None
}
}
})
.collect::<Vec<_>>(),
),
});


let middleware = tower::ServiceBuilder::new()
.option_layer(cors)
Expand Down
2 changes: 1 addition & 1 deletion crates/katana/runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl KatanaRunner {
.port(port)
.accounts(n_accounts)
.json_log(true)
.max_connections(10000)
.rpc_max_connections(10000)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Ohayo! We found an inconsistency in the API usage, sensei! 🔍

The method .max_connections() is still being used in crates/katana/node/src/lib.rs while it was renamed to .rpc_max_connections() in the runner. For consistency across the codebase, this should also be updated to use the new method name.

  • crates/katana/node/src/lib.rs: Update .max_connections(config.max_connections) to .rpc_max_connections(config.max_connections)
🔗 Analysis chain

Ohayo! The method rename improves API clarity, sensei! ✨

The change from .max_connections() to .rpc_max_connections() better specifies that this limit applies specifically to RPC connections, which aligns well with the PR's objective of improving configuration clarity.

Let's verify if there are any other occurrences of the old method name that need updating:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any remaining usage of the old method name
rg "\.max_connections\(" --type rust

Length of output: 117


Script:

#!/bin/bash
# Let's check the context of the max_connections usage in the node lib
rg "max_connections" -B 2 -A 2 crates/katana/node/src/lib.rs

# Also check if there's a Config struct definition to understand the field
ast-grep --pattern 'struct Config {
  $$$
  max_connections: $_,
  $$$
}'

Length of output: 320

.dev(config.dev)
.fee(!config.disable_fee);

Expand Down
Loading