diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 3eb8cf035c..b528173968 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -42,6 +42,7 @@ use self::subscriptions::event_message::EventMessageManager; use self::subscriptions::model_diff::{ModelDiffRequest, StateDiffManager}; use crate::proto::types::clause::ClauseType; use crate::proto::types::member_value::ValueType; +use crate::proto::types::LogicalOperator; use crate::proto::world::world_server::WorldServer; use crate::proto::world::{ SubscribeEntitiesRequest, SubscribeEntityResponse, SubscribeEventsResponse, @@ -260,7 +261,6 @@ impl DojoWorld { // total count of rows without limit and offset let total_count: u32 = sqlx::query_scalar(&count_query).fetch_optional(&self.pool).await?.unwrap_or(0); - if total_count == 0 { return Ok((Vec::new(), 0)); } @@ -382,7 +382,6 @@ impl DojoWorld { .fetch_optional(&self.pool) .await? .unwrap_or(0); - if total_count == 0 { return Ok((Vec::new(), 0)); } @@ -531,15 +530,13 @@ impl DojoWorld { "#, compute_selector_from_names(namespace, model) ); - - let models_result: Option<(String,)> = - sqlx::query_as(&models_query).fetch_optional(&self.pool).await?; - // we return an empty array of entities if the table is empty - if models_result.is_none() { + let models_str: Option = + sqlx::query_scalar(&models_query).fetch_optional(&self.pool).await?; + if models_str.is_none() { return Ok((Vec::new(), 0)); } - let (models_str,) = models_result.unwrap(); + let models_str = models_str.unwrap(); let model_ids = models_str .split(',') @@ -549,8 +546,14 @@ impl DojoWorld { let schemas = self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); - let table_name = member_clause.model; - let column_name = format!("external_{}", member_clause.member); + let model = member_clause.model.clone(); + let parts: Vec<&str> = member_clause.member.split('.').collect(); + let (table_name, column_name) = if parts.len() > 1 { + let nested_table = parts[..parts.len() - 1].join("$"); + (format!("{model}${nested_table}"), format!("external_{}", parts.last().unwrap())) + } else { + (model, format!("external_{}", member_clause.member)) + }; let (entity_query, arrays_queries, count_query) = build_sql_query( &schemas, table, @@ -566,7 +569,6 @@ impl DojoWorld { .fetch_optional(&self.pool) .await? .unwrap_or(0); - let db_entities = sqlx::query(&entity_query) .bind(comparison_value.clone()) .bind(limit) @@ -587,7 +589,7 @@ impl DojoWorld { Ok((entities_collection, total_count)) } - async fn query_by_composite( + pub(crate) async fn query_by_composite( &self, table: &str, model_relation_table: &str, @@ -596,102 +598,17 @@ impl DojoWorld { limit: Option, offset: Option, ) -> Result<(Vec, u32), Error> { - // different types of clauses - let mut where_clauses = Vec::new(); - let mut model_clauses: HashMap> = - HashMap::new(); - let mut having_clauses = Vec::new(); - - // bind valeus for prepared statement - let mut bind_values = Vec::new(); - - for clause in composite.clauses { - match clause.clause_type.unwrap() { - ClauseType::HashedKeys(hashed_keys) => { - let ids = hashed_keys - .hashed_keys - .iter() - .map(|id| { - Ok(format!("{table}.id = '{:#x}'", Felt::from_bytes_be_slice(id))) - }) - .collect::, Error>>()?; - where_clauses.push(format!("({})", ids.join(" OR "))); - } - ClauseType::Keys(keys) => { - let keys_pattern = build_keys_pattern(&keys)?; - where_clauses.push(format!("{table}.keys REGEXP '{keys_pattern}'")); - } - ClauseType::Member(member) => { - let comparison_operator = - ComparisonOperator::from_repr(member.operator as usize) - .expect("invalid comparison operator"); - let comparison_value = match member - .value - .ok_or(QueryError::MissingParam("value".into()))? - .value_type - { - Some(ValueType::String(value)) => value, - Some(ValueType::Primitive(value)) => { - let primitive: Primitive = value.try_into()?; - primitive.to_sql_value()? - } - None => return Err(QueryError::MissingParam("value_type".into()).into()), - }; - - let column_name = format!("external_{}", member.member); - - model_clauses.entry(member.model.clone()).or_default().push(( - column_name, - comparison_operator, - comparison_value, - )); - - let (namespace, model) = member - .model - .split_once('-') - .ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?; - let model_id: Felt = compute_selector_from_names(namespace, model); - having_clauses.push(format!("INSTR(model_ids, '{:#x}') > 0", model_id)); - } - _ => return Err(QueryError::UnsupportedQuery.into()), - } - } - - let mut join_clauses = Vec::new(); - for (model, clauses) in model_clauses { - let model_conditions = clauses - .into_iter() - .map(|(column, op, value)| { - bind_values.push(value); - format!("[{}].{} {} ?", model, column, op) - }) - .collect::>() - .join(" AND "); - - join_clauses.push(format!( - "JOIN [{}] ON [{}].id = [{}].entity_id AND ({})", - model, table, model, model_conditions - )); - } - - let join_clause = join_clauses.join(" "); - let where_clause = if !where_clauses.is_empty() { - format!("WHERE {}", where_clauses.join(" AND ")) - } else { - String::new() - }; - let having_clause = if !having_clauses.is_empty() { - format!("HAVING {}", having_clauses.join(" AND ")) - } else { - String::new() - }; + let (where_clause, having_clause, join_clause, bind_values) = + build_composite_clause(table, model_relation_table, &composite)?; let count_query = format!( r#" SELECT COUNT(DISTINCT [{table}].id) FROM [{table}] + JOIN {model_relation_table} ON [{table}].id = {model_relation_table}.entity_id {join_clause} {where_clause} + {having_clause} "# ); @@ -701,7 +618,6 @@ impl DojoWorld { } let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0); - if total_count == 0 { return Ok((Vec::new(), 0)); } @@ -721,7 +637,7 @@ impl DojoWorld { ); let mut db_query = sqlx::query_as(&query); - for value in bind_values { + for value in &bind_values { db_query = db_query.bind(value); } db_query = db_query.bind(limit.unwrap_or(u32::MAX)).bind(offset.unwrap_or(0)); @@ -1042,6 +958,111 @@ fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result Result<(String, String, String, Vec), Error> { + let is_or = composite.operator == LogicalOperator::Or as i32; + let mut where_clauses = Vec::new(); + let mut join_clauses = Vec::new(); + let mut having_clauses = Vec::new(); + let mut bind_values = Vec::new(); + + for clause in &composite.clauses { + match clause.clause_type.as_ref().unwrap() { + ClauseType::HashedKeys(hashed_keys) => { + let ids = hashed_keys + .hashed_keys + .iter() + .map(|id| { + bind_values.push(Felt::from_bytes_be_slice(id).to_string()); + "?".to_string() + }) + .collect::>() + .join(", "); + where_clauses.push(format!("{table}.id IN ({})", ids)); + } + ClauseType::Keys(keys) => { + let keys_pattern = build_keys_pattern(keys)?; + bind_values.push(keys_pattern); + where_clauses.push(format!("{table}.keys REGEXP ?")); + } + ClauseType::Member(member) => { + let comparison_operator = ComparisonOperator::from_repr(member.operator as usize) + .expect("invalid comparison operator"); + let value = member.value.clone(); + let comparison_value = + match value.ok_or(QueryError::MissingParam("value".into()))?.value_type { + Some(ValueType::String(value)) => value, + Some(ValueType::Primitive(value)) => { + let primitive: Primitive = value.try_into()?; + primitive.to_sql_value()? + } + None => return Err(QueryError::MissingParam("value_type".into()).into()), + }; + bind_values.push(comparison_value); + + let model = member.model.clone(); + let parts: Vec<&str> = member.member.split('.').collect(); + let (table_name, column_name) = if parts.len() > 1 { + let nested_table = parts[..parts.len() - 1].join("$"); + ( + format!("[{model}${nested_table}]"), + format!("external_{}", parts.last().unwrap()), + ) + } else { + (format!("[{model}]"), format!("external_{}", member.member)) + }; + + let (namespace, model) = member + .model + .split_once('-') + .ok_or(QueryError::InvalidNamespacedModel(member.model.clone()))?; + let model_id = compute_selector_from_names(namespace, model); + join_clauses.push(format!( + "LEFT JOIN {table_name} ON [{table}].id = {table_name}.entity_id" + )); + where_clauses.push(format!("{table_name}.{column_name} {comparison_operator} ?")); + having_clauses.push(format!( + "INSTR(group_concat({model_relation_table}.model_id), '{:#x}') > 0", + model_id + )); + } + ClauseType::Composite(nested_composite) => { + let (nested_where, nested_having, nested_join, nested_values) = + build_composite_clause(table, model_relation_table, nested_composite)?; + where_clauses.push(format!("({})", nested_where.trim_start_matches("WHERE "))); + if !nested_having.is_empty() { + having_clauses.push(nested_having.trim_start_matches("HAVING ").to_string()); + } + join_clauses.extend( + nested_join + .split_whitespace() + .filter(|&s| s.starts_with("LEFT")) + .map(String::from), + ); + bind_values.extend(nested_values); + } + } + } + + let join_clause = join_clauses.join(" "); + let where_clause = if !where_clauses.is_empty() { + format!("WHERE {}", where_clauses.join(if is_or { " OR " } else { " AND " })) + } else { + String::new() + }; + let having_clause = if !having_clauses.is_empty() { + format!("HAVING {}", having_clauses.join(if is_or { " OR " } else { " AND " })) + } else { + String::new() + }; + + Ok((where_clause, having_clause, join_clause, bind_values)) +} + type ServiceResult = Result, Status>; type SubscribeModelsResponseStream = Pin> + Send>>;