diff --git a/crates/torii/core/src/executor.rs b/crates/torii/core/src/executor.rs index a9dee8bdf2..11763f64e0 100644 --- a/crates/torii/core/src/executor.rs +++ b/crates/torii/core/src/executor.rs @@ -1,4 +1,4 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::mem; use anyhow::{Context, Result}; @@ -6,7 +6,7 @@ use dojo_types::schema::{Struct, Ty}; use sqlx::query::Query; use sqlx::sqlite::SqliteArguments; use sqlx::{FromRow, Pool, Sqlite, Transaction}; -use starknet::core::types::Felt; +use starknet::core::types::{Felt, U256}; use tokio::sync::broadcast::{Receiver, Sender}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::oneshot; @@ -14,9 +14,11 @@ use tokio::time::Instant; use tracing::{debug, error}; use crate::simple_broker::SimpleBroker; +use crate::sql::utils::{sql_string_to_u256, u256_to_sql_string, I256}; +use crate::sql::FELT_DELIMITER; use crate::types::{ - Entity as EntityUpdated, Event as EventEmitted, EventMessage as EventMessageUpdated, - Model as ModelRegistered, + ContractType, Entity as EntityUpdated, Event as EventEmitted, + EventMessage as EventMessageUpdated, Model as ModelRegistered, }; pub(crate) const LOG_TARGET: &str = "torii_core::executor"; @@ -46,11 +48,17 @@ pub struct DeleteEntityQuery { pub ty: Ty, } +#[derive(Debug, Clone)] +pub struct ApplyBalanceDiffQuery { + pub erc_cache: HashMap<(ContractType, String), I256>, +} + #[derive(Debug, Clone)] pub enum QueryType { SetEntity(Ty), DeleteEntity(DeleteEntityQuery), EventMessage(Ty), + ApplyBalanceDiff(ApplyBalanceDiffQuery), RegisterModel, StoreEvent, Execute, @@ -59,6 +67,8 @@ pub enum QueryType { #[derive(Debug)] pub struct Executor<'c> { + // Queries should use `transaction` instead of `pool` + // This `pool` is only used to create a new `transaction` pool: Pool, transaction: Transaction<'c, Sqlite>, publish_queue: VecDeque, @@ -252,6 +262,12 @@ impl<'c> Executor<'c> { let event = EventEmitted::from_row(&row)?; self.publish_queue.push_back(BrokerMessage::EventEmitted(event)); } + QueryType::ApplyBalanceDiff(apply_balance_diff) => { + debug!(target: LOG_TARGET, "Applying balance diff."); + let instant = Instant::now(); + self.apply_balance_diff(apply_balance_diff).await?; + debug!(target: LOG_TARGET, duration = ?instant.elapsed(), "Applied balance diff."); + } QueryType::Execute => { debug!(target: LOG_TARGET, "Executing query."); let instant = Instant::now(); @@ -286,6 +302,102 @@ impl<'c> Executor<'c> { Ok(()) } + + async fn apply_balance_diff( + &mut self, + apply_balance_diff: ApplyBalanceDiffQuery, + ) -> Result<()> { + let erc_cache = apply_balance_diff.erc_cache; + for ((contract_type, id_str), balance) in erc_cache.iter() { + let id = id_str.split(FELT_DELIMITER).collect::>(); + match contract_type { + ContractType::WORLD => unreachable!(), + ContractType::ERC721 => { + // account_address/contract_address:id => ERC721 + assert!(id.len() == 2); + let account_address = id[0]; + let token_id = id[1]; + let mid = token_id.split(":").collect::>(); + let contract_address = mid[0]; + + self.apply_balance_diff_helper( + id_str, + account_address, + contract_address, + token_id, + balance, + ) + .await + .with_context(|| "Failed to apply balance diff in apply_cache_diff")?; + } + ContractType::ERC20 => { + // account_address/contract_address/ => ERC20 + assert!(id.len() == 3); + let account_address = id[0]; + let contract_address = id[1]; + let token_id = id[1]; + + self.apply_balance_diff_helper( + id_str, + account_address, + contract_address, + token_id, + balance, + ) + .await + .with_context(|| "Failed to apply balance diff in apply_cache_diff")?; + } + } + } + + Ok(()) + } + + async fn apply_balance_diff_helper( + &mut self, + id: &str, + account_address: &str, + contract_address: &str, + token_id: &str, + balance_diff: &I256, + ) -> Result<()> { + let tx = &mut self.transaction; + let balance: Option<(String,)> = + sqlx::query_as("SELECT balance FROM balances WHERE id = ?") + .bind(id) + .fetch_optional(&mut **tx) + .await?; + + let mut balance = if let Some(balance) = balance { + sql_string_to_u256(&balance.0) + } else { + U256::from(0u8) + }; + + if balance_diff.is_negative { + if balance < balance_diff.value { + dbg!(&balance_diff, balance, id); + } + balance -= balance_diff.value; + } else { + balance += balance_diff.value; + } + + // write the new balance to the database + sqlx::query( + "INSERT OR REPLACE INTO balances (id, contract_address, account_address, token_id, \ + balance) VALUES (?, ?, ?, ?, ?)", + ) + .bind(id) + .bind(contract_address) + .bind(account_address) + .bind(token_id) + .bind(u256_to_sql_string(&balance)) + .execute(&mut **tx) + .await?; + + Ok(()) + } } fn send_broker_message(message: BrokerMessage) { diff --git a/crates/torii/core/src/sql/erc.rs b/crates/torii/core/src/sql/erc.rs index 78e064f258..4e31e03743 100644 --- a/crates/torii/core/src/sql/erc.rs +++ b/crates/torii/core/src/sql/erc.rs @@ -1,13 +1,16 @@ -use anyhow::Result; +use std::collections::HashMap; +use std::mem; + +use anyhow::{Context, Result}; use cainome::cairo_serde::{ByteArray, CairoSerde}; use starknet::core::types::{BlockId, BlockTag, Felt, FunctionCall, U256}; use starknet::core::utils::{get_selector_from_name, parse_cairo_short_string}; use starknet::providers::Provider; use tracing::debug; -use super::utils::{sql_string_to_u256, u256_to_sql_string, I256}; +use super::utils::{u256_to_sql_string, I256}; use super::{Sql, FELT_DELIMITER}; -use crate::executor::{Argument, QueryMessage}; +use crate::executor::{ApplyBalanceDiffQuery, Argument, QueryMessage, QueryType}; use crate::sql::utils::{felt_and_u256_to_sql_string, felt_to_sql_string, felts_to_sql_string}; use crate::types::ContractType; use crate::utils::utc_dt_string_from_timestamp; @@ -30,7 +33,7 @@ impl Sql { if !token_exists { self.register_erc20_token_metadata(contract_address, &token_id, provider).await?; - self.execute().await?; + self.execute().await.with_context(|| "Failed to execute in handle_erc20_transfer")?; } self.store_erc_transfer_event( @@ -332,92 +335,16 @@ impl Sql { } pub async fn apply_cache_diff(&mut self) -> Result<()> { - for ((contract_type, id_str), balance) in self.local_cache.erc_cache.iter() { - let id = id_str.split(FELT_DELIMITER).collect::>(); - match contract_type { - ContractType::WORLD => unreachable!(), - ContractType::ERC721 => { - // account_address/contract_address:id => ERC721 - assert!(id.len() == 2); - let account_address = id[0]; - let token_id = id[1]; - let mid = token_id.split(":").collect::>(); - let contract_address = mid[0]; - - self.apply_balance_diff( - id_str, - account_address, - contract_address, - token_id, - balance, - ) - .await?; - } - ContractType::ERC20 => { - // account_address/contract_address/ => ERC20 - assert!(id.len() == 3); - let account_address = id[0]; - let contract_address = id[1]; - let token_id = id[1]; - - self.apply_balance_diff( - id_str, - account_address, - contract_address, - token_id, - balance, - ) - .await?; - } - } - } - - self.local_cache.erc_cache.clear(); - Ok(()) - } - - async fn apply_balance_diff( - &self, - id: &str, - account_address: &str, - contract_address: &str, - token_id: &str, - balance_diff: &I256, - ) -> Result<()> { - let balance: Option<(String,)> = - sqlx::query_as("SELECT balance FROM balances WHERE id = ?") - .bind(id) - .fetch_optional(&self.pool) - .await?; - - let mut balance = if let Some(balance) = balance { - sql_string_to_u256(&balance.0) - } else { - U256::from(0u8) - }; - - if balance_diff.is_negative { - if balance < balance_diff.value { - dbg!(&balance_diff, balance, id); - } - balance -= balance_diff.value; - } else { - balance += balance_diff.value; - } - - // write the new balance to the database - sqlx::query( - "INSERT OR REPLACE INTO balances (id, contract_address, account_address, token_id, \ - balance) VALUES (?, ?, ?, ?, ?)", - ) - .bind(id) - .bind(contract_address) - .bind(account_address) - .bind(token_id) - .bind(u256_to_sql_string(&balance)) - .execute(&self.pool) - .await?; - + self.executor.send(QueryMessage::new( + "".to_string(), + vec![], + QueryType::ApplyBalanceDiff(ApplyBalanceDiffQuery { + erc_cache: mem::replace( + &mut self.local_cache.erc_cache, + HashMap::with_capacity(64), + ), + }), + ))?; Ok(()) } }