From 0ccf72c83db3a453580f6fbc1e50536f104022f4 Mon Sep 17 00:00:00 2001 From: Michael Vlach Date: Sun, 10 Nov 2024 10:44:42 +0100 Subject: [PATCH 1/3] rename --- agdb_server/src/db_pool.rs | 46 +++++------ .../src/db_pool/{server_db.rs => user_db.rs} | 8 +- ...erver_db_storage.rs => user_db_storage.rs} | 76 +++++++++---------- 3 files changed, 65 insertions(+), 65 deletions(-) rename agdb_server/src/db_pool/{server_db.rs => user_db.rs} (77%) rename agdb_server/src/db_pool/{server_db_storage.rs => user_db_storage.rs} (72%) diff --git a/agdb_server/src/db_pool.rs b/agdb_server/src/db_pool.rs index 8ab08e8c..0226903e 100644 --- a/agdb_server/src/db_pool.rs +++ b/agdb_server/src/db_pool.rs @@ -1,8 +1,8 @@ -mod server_db; -mod server_db_storage; +mod user_db; +mod user_db_storage; use crate::config::Config; -use crate::db_pool::server_db_storage::ServerDbStorage; +use crate::db_pool::user_db_storage::UserDbStorage; use crate::error_code::ErrorCode; use crate::password; use crate::password::Password; @@ -35,8 +35,6 @@ use agdb_api::QueryAudit; use agdb_api::ServerDatabase; use agdb_api::UserStatus; use axum::http::StatusCode; -use server_db::ServerDb; -use server_db::ServerDbImpl; use std::collections::HashMap; use std::io::Seek; use std::io::SeekFrom; @@ -49,6 +47,8 @@ use std::time::UNIX_EPOCH; use tokio::sync::RwLock; use tokio::sync::RwLockReadGuard; use tokio::sync::RwLockWriteGuard; +use user_db::ServerDbImpl; +use user_db::UserDb; use uuid::Uuid; #[derive(UserValue)] @@ -69,8 +69,8 @@ struct Database { } pub(crate) struct DbPoolImpl { - server_db: ServerDb, - pool: RwLock>, + server_db: UserDb, + pool: RwLock>, } #[derive(Clone)] @@ -83,7 +83,7 @@ impl DbPool { .join("agdb_server.agdb") .exists(); let db_pool = Self(Arc::new(DbPoolImpl { - server_db: ServerDb::new(&format!("mapped:{}/agdb_server.agdb", config.data_dir))?, + server_db: UserDb::new(&format!("mapped:{}/agdb_server.agdb", config.data_dir))?, pool: RwLock::new(HashMap::new()), })); @@ -157,7 +157,7 @@ impl DbPool { let db_path = db_file(owner, db_name, config); std::fs::create_dir_all(db_audit_dir(owner, config))?; let server_db = - ServerDb::new(&format!("{}:{}", db.db_type, db_path.to_string_lossy()))?; + UserDb::new(&format!("{}:{}", db.db_type, db_path.to_string_lossy()))?; db_pool.0.pool.write().await.insert(db.name, server_db); } } else { @@ -217,7 +217,7 @@ impl DbPool { std::fs::create_dir_all(db_audit_dir(owner, config))?; let path = db_path.to_str().ok_or(ErrorCode::DbInvalid)?.to_string(); - let server_db = ServerDb::new(&format!("{}:{}", db_type, path)).map_err(|mut e| { + let server_db = UserDb::new(&format!("{}:{}", db_type, path)).map_err(|mut e| { e.status = ErrorCode::DbInvalid.into(); e.description = format!("{}: {}", ErrorCode::DbInvalid.as_str(), e.description); e @@ -470,7 +470,7 @@ impl DbPool { let server_db = pool .get_mut(&database.name) .ok_or(db_not_found(&database.name))?; - *server_db = ServerDb::new(&format!("{}:{}", DbType::Memory, database.name))?; + *server_db = UserDb::new(&format!("{}:{}", DbType::Memory, database.name))?; if database.db_type != DbType::Memory { let main_file = db_file(owner, db, config); if main_file.exists() { @@ -484,7 +484,7 @@ impl DbPool { let db_path = Path::new(&config.data_dir).join(&db_name); let path = db_path.to_str().ok_or(ErrorCode::DbInvalid)?.to_string(); - *server_db = ServerDb::new(&format!("{}:{}", database.db_type, path))?; + *server_db = UserDb::new(&format!("{}:{}", database.db_type, path))?; } Ok(()) @@ -495,7 +495,7 @@ impl DbPool { owner: &str, db: &str, config: &Config, - server_db: &ServerDb, + server_db: &UserDb, database: &mut Database, ) -> Result<(), ServerError> { let backup_path = if database.db_type == DbType::Memory { @@ -540,7 +540,7 @@ impl DbPool { pool.remove(&db_name); let current_path = db_file(owner, db, config); - let server_db = ServerDb::new(&format!("{}:{}", db_type, current_path.to_string_lossy()))?; + let server_db = UserDb::new(&format!("{}:{}", db_type, current_path.to_string_lossy()))?; pool.insert(db_name, server_db); database.db_type = db_type; @@ -1009,7 +1009,7 @@ impl DbPool { owner: &str, db: &str, user: DbId, - ) -> ServerResult { + ) -> ServerResult { let user_name = self.user_name(user).await?; if owner != user_name { @@ -1084,7 +1084,7 @@ impl DbPool { .await?; } - let server_db = ServerDb( + let server_db = UserDb( self.get_pool() .await .get(&db_name) @@ -1165,7 +1165,7 @@ impl DbPool { database.backup = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); } - let server_db = ServerDb::new(&format!( + let server_db = UserDb::new(&format!( "{}:{}", database.db_type, current_path.to_string_lossy() @@ -1400,11 +1400,11 @@ impl DbPool { .into()) } - async fn get_pool(&self) -> RwLockReadGuard> { + async fn get_pool(&self) -> RwLockReadGuard> { self.0.pool.read().await } - async fn get_pool_mut(&self) -> RwLockWriteGuard> { + async fn get_pool_mut(&self) -> RwLockWriteGuard> { self.0.pool.write().await } @@ -1505,7 +1505,7 @@ fn required_role(queries: &Queries) -> DbUserRole { } fn t_exec( - t: &Transaction, + t: &Transaction, q: &mut QueryType, results: &[QueryResult], ) -> Result { @@ -1539,7 +1539,7 @@ fn t_exec( } fn t_exec_mut( - t: &mut TransactionMut, + t: &mut TransactionMut, mut q: QueryType, results: &[QueryResult], audit: &mut Vec, @@ -1716,13 +1716,13 @@ fn audit_query(user: &str, audit: &mut Vec, query: QueryType) { #[cfg(test)] mod tests { use super::*; - use crate::db_pool::server_db::ServerDb; + use crate::db_pool::user_db::UserDb; use agdb::QueryBuilder; #[tokio::test] #[should_panic] async fn unreachable() { - let db = ServerDb::new("memory:test").unwrap(); + let db = UserDb::new("memory:test").unwrap(); db.get() .await .transaction(|t| { diff --git a/agdb_server/src/db_pool/server_db.rs b/agdb_server/src/db_pool/user_db.rs similarity index 77% rename from agdb_server/src/db_pool/server_db.rs rename to agdb_server/src/db_pool/user_db.rs index 03150a44..554fc6c6 100644 --- a/agdb_server/src/db_pool/server_db.rs +++ b/agdb_server/src/db_pool/user_db.rs @@ -1,4 +1,4 @@ -use crate::db_pool::server_db_storage::ServerDbStorage; +use crate::db_pool::user_db_storage::UserDbStorage; use crate::server_error::ServerResult; use agdb::DbImpl; use std::sync::Arc; @@ -6,10 +6,10 @@ use tokio::sync::RwLock; use tokio::sync::RwLockReadGuard; use tokio::sync::RwLockWriteGuard; -pub(crate) type ServerDbImpl = DbImpl; -pub(crate) struct ServerDb(pub(crate) Arc>); +pub(crate) type ServerDbImpl = DbImpl; +pub(crate) struct UserDb(pub(crate) Arc>); -impl ServerDb { +impl UserDb { pub(crate) fn new(name: &str) -> ServerResult { Ok(Self(Arc::new(RwLock::new(ServerDbImpl::new(name)?)))) } diff --git a/agdb_server/src/db_pool/server_db_storage.rs b/agdb_server/src/db_pool/user_db_storage.rs similarity index 72% rename from agdb_server/src/db_pool/server_db_storage.rs rename to agdb_server/src/db_pool/user_db_storage.rs index 297a84b8..3b239f29 100644 --- a/agdb_server/src/db_pool/server_db_storage.rs +++ b/agdb_server/src/db_pool/user_db_storage.rs @@ -5,50 +5,50 @@ use agdb::MemoryStorage; use agdb::StorageData; use agdb::StorageSlice; -pub(crate) enum ServerDbStorage { +pub(crate) enum UserDbStorage { MemoryMapped(FileStorageMemoryMapped), Memory(MemoryStorage), File(FileStorage), } -impl StorageData for ServerDbStorage { +impl StorageData for UserDbStorage { fn backup(&self, name: &str) -> Result<(), DbError> { match self { - ServerDbStorage::MemoryMapped(s) => s.backup(name), - ServerDbStorage::Memory(s) => s.backup(name), - ServerDbStorage::File(s) => s.backup(name), + UserDbStorage::MemoryMapped(s) => s.backup(name), + UserDbStorage::Memory(s) => s.backup(name), + UserDbStorage::File(s) => s.backup(name), } } fn copy(&self, name: &str) -> Result { Ok(match self { - ServerDbStorage::MemoryMapped(s) => ServerDbStorage::MemoryMapped(s.copy(name)?), - ServerDbStorage::Memory(s) => ServerDbStorage::Memory(s.copy(name)?), - ServerDbStorage::File(s) => ServerDbStorage::File(s.copy(name)?), + UserDbStorage::MemoryMapped(s) => UserDbStorage::MemoryMapped(s.copy(name)?), + UserDbStorage::Memory(s) => UserDbStorage::Memory(s.copy(name)?), + UserDbStorage::File(s) => UserDbStorage::File(s.copy(name)?), }) } fn flush(&mut self) -> Result<(), DbError> { match self { - ServerDbStorage::MemoryMapped(s) => s.flush(), - ServerDbStorage::Memory(s) => s.flush(), - ServerDbStorage::File(s) => s.flush(), + UserDbStorage::MemoryMapped(s) => s.flush(), + UserDbStorage::Memory(s) => s.flush(), + UserDbStorage::File(s) => s.flush(), } } fn len(&self) -> u64 { match self { - ServerDbStorage::MemoryMapped(s) => s.len(), - ServerDbStorage::Memory(s) => s.len(), - ServerDbStorage::File(s) => s.len(), + UserDbStorage::MemoryMapped(s) => s.len(), + UserDbStorage::Memory(s) => s.len(), + UserDbStorage::File(s) => s.len(), } } fn name(&self) -> &str { match self { - ServerDbStorage::MemoryMapped(s) => s.name(), - ServerDbStorage::Memory(s) => s.name(), - ServerDbStorage::File(s) => s.name(), + UserDbStorage::MemoryMapped(s) => s.name(), + UserDbStorage::Memory(s) => s.name(), + UserDbStorage::File(s) => s.name(), } } @@ -68,41 +68,41 @@ impl StorageData for ServerDbStorage { fn read(&self, pos: u64, value_len: u64) -> Result { match self { - ServerDbStorage::MemoryMapped(s) => s.read(pos, value_len), - ServerDbStorage::Memory(s) => s.read(pos, value_len), - ServerDbStorage::File(s) => s.read(pos, value_len), + UserDbStorage::MemoryMapped(s) => s.read(pos, value_len), + UserDbStorage::Memory(s) => s.read(pos, value_len), + UserDbStorage::File(s) => s.read(pos, value_len), } } fn rename(&mut self, new_name: &str) -> Result<(), DbError> { match self { - ServerDbStorage::MemoryMapped(s) => s.rename(new_name), - ServerDbStorage::Memory(s) => s.rename(new_name), - ServerDbStorage::File(s) => s.rename(new_name), + UserDbStorage::MemoryMapped(s) => s.rename(new_name), + UserDbStorage::Memory(s) => s.rename(new_name), + UserDbStorage::File(s) => s.rename(new_name), } } fn resize(&mut self, new_len: u64) -> Result<(), DbError> { match self { - ServerDbStorage::MemoryMapped(s) => s.resize(new_len), - ServerDbStorage::Memory(s) => s.resize(new_len), - ServerDbStorage::File(s) => s.resize(new_len), + UserDbStorage::MemoryMapped(s) => s.resize(new_len), + UserDbStorage::Memory(s) => s.resize(new_len), + UserDbStorage::File(s) => s.resize(new_len), } } fn write(&mut self, pos: u64, bytes: &[u8]) -> Result<(), DbError> { match self { - ServerDbStorage::MemoryMapped(s) => s.write(pos, bytes), - ServerDbStorage::Memory(s) => s.write(pos, bytes), - ServerDbStorage::File(s) => s.write(pos, bytes), + UserDbStorage::MemoryMapped(s) => s.write(pos, bytes), + UserDbStorage::Memory(s) => s.write(pos, bytes), + UserDbStorage::File(s) => s.write(pos, bytes), } } fn is_empty(&self) -> bool { match self { - ServerDbStorage::MemoryMapped(s) => s.is_empty(), - ServerDbStorage::Memory(s) => s.is_empty(), - ServerDbStorage::File(s) => s.is_empty(), + UserDbStorage::MemoryMapped(s) => s.is_empty(), + UserDbStorage::Memory(s) => s.is_empty(), + UserDbStorage::File(s) => s.is_empty(), } } } @@ -126,7 +126,7 @@ mod tests { } } - impl std::fmt::Debug for ServerDbStorage { + impl std::fmt::Debug for UserDbStorage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::MemoryMapped(_) => f.write_str("MemoryMapped"), @@ -145,7 +145,7 @@ mod tests { let test_file_copy_dot = TestFile::new(".file_storage_rename_copy.agdb"); let _test_file_rename_dot = TestFile::new(".file_storage_rename.agdb"); let test_file_backup = TestFile::new("file_storage_backup.agdb"); - let mut storage = ServerDbStorage::new(&format!("file:{}", test_file.0))?; + let mut storage = UserDbStorage::new(&format!("file:{}", test_file.0))?; let _ = format!("{:?}", storage); storage.backup(&test_file_backup.0)?; assert!(std::path::Path::new(&test_file_backup.0).exists()); @@ -173,7 +173,7 @@ mod tests { let test_file_copy_dot = TestFile::new(".mapped_storage_copy.agdb"); let _test_file_rename_dot = TestFile::new(".mapped_storage_rename.agdb"); let test_file2 = TestFile::new("mapped_storage_backup.agdb"); - let mut storage = ServerDbStorage::new(&format!("mapped:{}", test_file.0))?; + let mut storage = UserDbStorage::new(&format!("mapped:{}", test_file.0))?; let _ = format!("{:?}", storage); storage.backup(&test_file2.0)?; assert!(std::path::Path::new(&test_file2.0).exists()); @@ -194,7 +194,7 @@ mod tests { #[test] fn memory_storage() -> anyhow::Result<()> { - let mut storage = ServerDbStorage::new("memory:db_test.agdb")?; + let mut storage = UserDbStorage::new("memory:db_test.agdb")?; let _ = format!("{:?}", storage); storage.backup("backup_test")?; let other = storage.copy("db_test_copy.agdb")?; @@ -213,7 +213,7 @@ mod tests { #[test] fn invalid_db_name() { assert_eq!( - ServerDbStorage::new("db.agdb").unwrap_err().description, + UserDbStorage::new("db.agdb").unwrap_err().description, "Invalid server database name format, must be 'type:name'. Allowed types: mapped, memory, file." ); } @@ -221,7 +221,7 @@ mod tests { #[test] fn invalid_db_type() { assert_eq!( - ServerDbStorage::new("sometype:db.agdb") + UserDbStorage::new("sometype:db.agdb") .unwrap_err() .description, "Invalid db type 'sometype', must be one of 'mapped', 'memory', 'file'." From 10f82d0e533da63171dac664506c60e784e6b6cf Mon Sep 17 00:00:00 2001 From: Michael Vlach Date: Sun, 10 Nov 2024 13:25:43 +0100 Subject: [PATCH 2/3] extract functionality --- agdb_server/src/db_pool.rs | 327 +++------------------------ agdb_server/src/db_pool/server_db.rs | 24 ++ agdb_server/src/db_pool/user_db.rs | 309 ++++++++++++++++++++++++- 3 files changed, 354 insertions(+), 306 deletions(-) create mode 100644 agdb_server/src/db_pool/server_db.rs diff --git a/agdb_server/src/db_pool.rs b/agdb_server/src/db_pool.rs index 0226903e..43f9b5d8 100644 --- a/agdb_server/src/db_pool.rs +++ b/agdb_server/src/db_pool.rs @@ -1,8 +1,9 @@ +mod server_db; mod user_db; mod user_db_storage; use crate::config::Config; -use crate::db_pool::user_db_storage::UserDbStorage; +use crate::db_pool::server_db::ServerDb; use crate::error_code::ErrorCode; use crate::password; use crate::password::Password; @@ -14,15 +15,10 @@ use agdb::CountComparison; use agdb::DbId; use agdb::DbUserValue; use agdb::QueryBuilder; -use agdb::QueryConditionData; -use agdb::QueryError; use agdb::QueryId; -use agdb::QueryIds; use agdb::QueryResult; use agdb::QueryType; use agdb::SearchQuery; -use agdb::Transaction; -use agdb::TransactionMut; use agdb::UserValue; use agdb_api::AdminStatus; use agdb_api::DbAudit; @@ -31,7 +27,6 @@ use agdb_api::DbType; use agdb_api::DbUser; use agdb_api::DbUserRole; use agdb_api::Queries; -use agdb_api::QueryAudit; use agdb_api::ServerDatabase; use agdb_api::UserStatus; use axum::http::StatusCode; @@ -47,8 +42,8 @@ use std::time::UNIX_EPOCH; use tokio::sync::RwLock; use tokio::sync::RwLockReadGuard; use tokio::sync::RwLockWriteGuard; -use user_db::ServerDbImpl; use user_db::UserDb; +use user_db::UserDbImpl; use uuid::Uuid; #[derive(UserValue)] @@ -69,7 +64,7 @@ struct Database { } pub(crate) struct DbPoolImpl { - server_db: UserDb, + server_db: ServerDb, pool: RwLock>, } @@ -83,7 +78,7 @@ impl DbPool { .join("agdb_server.agdb") .exists(); let db_pool = Self(Arc::new(DbPoolImpl { - server_db: UserDb::new(&format!("mapped:{}/agdb_server.agdb", config.data_dir))?, + server_db: ServerDb::new(&format!("mapped:{}/agdb_server.agdb", config.data_dir))?, pool: RwLock::new(HashMap::new()), })); @@ -425,9 +420,8 @@ impl DbPool { .await .get(&database.name) .ok_or(db_not_found(&database.name))? - .get() - .await - .size(); + .size() + .await; Ok(ServerDatabase { name: database.name, @@ -509,9 +503,9 @@ impl DbPool { std::fs::create_dir_all(db_backup_dir(owner, config))?; } server_db - .get_mut() - .await - .backup(backup_path.to_string_lossy().as_ref())?; + .backup(backup_path.to_string_lossy().as_ref()) + .await?; + database.backup = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); self.save_db(database).await?; Ok(()) @@ -662,7 +656,7 @@ impl DbPool { owner: &str, db: &str, user: DbId, - mut queries: Queries, + queries: Queries, config: &Config, ) -> ServerResult> { let db_name = db_name(owner, db); @@ -678,42 +672,20 @@ impl DbPool { .await .get(&db_name) .ok_or(db_not_found(&db_name))? - .get() + .exec(queries) .await - .transaction(|t| { - let mut results = vec![]; - - for q in queries.0.iter_mut() { - let result = t_exec(t, q, &results)?; - results.push(result); - } - - Ok(results) - }) } else { let username = self.user_name(user).await?; - let mut audit = vec![]; - let r = self + let (r, audit) = self .get_pool() .await .get(&db_name) .ok_or(db_not_found(&db_name))? - .get_mut() - .await - .transaction_mut(|t| { - let mut results = vec![]; - let mut qs = vec![]; - std::mem::swap(&mut queries.0, &mut qs); - - for q in qs { - let result = t_exec_mut(t, q, &results, &mut audit, &username)?; - results.push(result); - } + .exec_mut(queries, &username) + .await?; - Ok(results) - }); - if r.is_ok() && !audit.is_empty() { + if !audit.is_empty() { let mut log = std::fs::OpenOptions::new() .create(true) .truncate(false) @@ -729,9 +701,9 @@ impl DbPool { log.write_all(&data)?; } } - r - } - .map_err(|e: QueryError| ServerError::new(ErrorCode::QueryError.into(), &e.description))?; + + Ok(r) + }?; Ok(results) } @@ -764,9 +736,8 @@ impl DbPool { size: pool .get(&db.name) .ok_or(db_not_found(&db.name))? - .get() - .await - .size(), + .size() + .await, backup: db.backup, name: db.name, }); @@ -839,9 +810,8 @@ impl DbPool { .await .get(&db.name) .ok_or(db_not_found(&db.name))? - .get() - .await - .size(); + .size() + .await; server_db.name = db.name; } } @@ -958,8 +928,8 @@ impl DbPool { let pool = self.get_pool().await; let server_db = pool.get(&db.name).ok_or(db_not_found(&db.name))?; - server_db.get_mut().await.optimize_storage()?; - let size = server_db.get().await.size(); + server_db.optimize_storage().await?; + let size = server_db.size().await; Ok(ServerDatabase { name: db.name, @@ -1094,14 +1064,11 @@ impl DbPool { ); server_db - .get_mut() - .await .rename(target_name.to_string_lossy().as_ref()) - .map_err(|e| { - ServerError::new( - ErrorCode::DbInvalid.into(), - &format!("db rename error: {}", e.description), - ) + .await + .map_err(|mut e| { + e.status = ErrorCode::DbInvalid.into(); + e })?; let backup_path = db_backup_file(owner, db, config); @@ -1268,11 +1235,11 @@ impl DbPool { .to_string()) } - async fn db(&self) -> RwLockReadGuard { + async fn db(&self) -> RwLockReadGuard { self.0.server_db.get().await } - async fn db_mut(&self) -> RwLockWriteGuard { + async fn db_mut(&self) -> RwLockWriteGuard { self.0.server_db.get_mut().await } @@ -1503,235 +1470,3 @@ fn required_role(queries: &Queries) -> DbUserRole { DbUserRole::Read } - -fn t_exec( - t: &Transaction, - q: &mut QueryType, - results: &[QueryResult], -) -> Result { - match q { - QueryType::Search(q) => { - inject_results_search(q, results)?; - t.exec(&*q) - } - QueryType::SelectAliases(q) => { - inject_results(&mut q.0, results)?; - t.exec(&*q) - } - QueryType::SelectAllAliases(q) => t.exec(&*q), - QueryType::SelectEdgeCount(q) => t.exec(&*q), - QueryType::SelectIndexes(q) => t.exec(&*q), - QueryType::SelectKeys(q) => { - inject_results(&mut q.0, results)?; - t.exec(&*q) - } - QueryType::SelectKeyCount(q) => { - inject_results(&mut q.0, results)?; - t.exec(&*q) - } - QueryType::SelectNodeCount(q) => t.exec(&*q), - QueryType::SelectValues(q) => { - inject_results(&mut q.ids, results)?; - t.exec(&*q) - } - _ => unreachable!(), - } -} - -fn t_exec_mut( - t: &mut TransactionMut, - mut q: QueryType, - results: &[QueryResult], - audit: &mut Vec, - username: &str, -) -> Result { - let mut do_audit = false; - - let r = match &mut q { - QueryType::Search(q) => { - inject_results_search(q, results)?; - t.exec(&*q) - } - QueryType::SelectAliases(q) => { - inject_results(&mut q.0, results)?; - t.exec(&*q) - } - QueryType::SelectAllAliases(q) => t.exec(&*q), - QueryType::SelectEdgeCount(q) => t.exec(&*q), - QueryType::SelectIndexes(q) => t.exec(&*q), - QueryType::SelectKeys(q) => { - inject_results(&mut q.0, results)?; - t.exec(&*q) - } - QueryType::SelectKeyCount(q) => { - inject_results(&mut q.0, results)?; - t.exec(&*q) - } - QueryType::SelectNodeCount(q) => t.exec(&*q), - QueryType::SelectValues(q) => { - inject_results(&mut q.ids, results)?; - t.exec(&*q) - } - QueryType::InsertAlias(q) => { - do_audit = true; - inject_results(&mut q.ids, results)?; - t.exec_mut(&*q) - } - QueryType::InsertEdges(q) => { - do_audit = true; - inject_results(&mut q.ids, results)?; - inject_results(&mut q.from, results)?; - inject_results(&mut q.to, results)?; - - t.exec_mut(&*q) - } - QueryType::InsertNodes(q) => { - do_audit = true; - inject_results(&mut q.ids, results)?; - t.exec_mut(&*q) - } - QueryType::InsertValues(q) => { - do_audit = true; - inject_results(&mut q.ids, results)?; - t.exec_mut(&*q) - } - QueryType::Remove(q) => { - do_audit = true; - inject_results(&mut q.0, results)?; - t.exec_mut(&*q) - } - QueryType::InsertIndex(q) => { - do_audit = true; - t.exec_mut(&*q) - } - QueryType::RemoveAliases(q) => { - do_audit = true; - t.exec_mut(&*q) - } - QueryType::RemoveIndex(q) => { - do_audit = true; - t.exec_mut(&*q) - } - QueryType::RemoveValues(q) => { - do_audit = true; - inject_results(&mut q.0.ids, results)?; - t.exec_mut(&*q) - } - }; - - if do_audit { - audit_query(username, audit, q); - } - - r -} - -fn id_or_result(id: QueryId, results: &[QueryResult]) -> Result { - if let QueryId::Alias(alias) = &id { - if let Some(index) = alias.strip_prefix(':') { - if let Ok(index) = index.parse::() { - return Ok(QueryId::Id( - results - .get(index) - .ok_or(QueryError::from(format!( - "Results index out of bounds '{index}' (> {})", - results.len() - )))? - .elements - .first() - .ok_or(QueryError::from("No element found in the result"))? - .id, - )); - } - } - } - - Ok(id) -} - -fn inject_results(ids: &mut QueryIds, results: &[QueryResult]) -> Result<(), QueryError> { - match ids { - QueryIds::Ids(ids) => { - inject_results_ids(ids, results)?; - } - QueryIds::Search(search) => { - inject_results_search(search, results)?; - } - } - - Ok(()) -} - -fn inject_results_search( - search: &mut SearchQuery, - results: &[QueryResult], -) -> Result<(), QueryError> { - search.origin = id_or_result(search.origin.clone(), results)?; - search.destination = id_or_result(search.destination.clone(), results)?; - - for c in &mut search.conditions { - if let QueryConditionData::Ids(ids) = &mut c.data { - inject_results_ids(ids, results)?; - } - } - - Ok(()) -} - -fn inject_results_ids(ids: &mut Vec, results: &[QueryResult]) -> Result<(), QueryError> { - for i in 0..ids.len() { - if let QueryId::Alias(alias) = &ids[i] { - if let Some(index) = alias.strip_prefix(':') { - if let Ok(index) = index.parse::() { - let result_ids = results - .get(index) - .ok_or(QueryError::from(format!( - "Results index out of bounds '{index}' (> {})", - results.len() - )))? - .ids() - .into_iter() - .map(QueryId::Id) - .collect::>(); - ids.splice(i..i + 1, result_ids.into_iter()); - } - } - } - } - - Ok(()) -} - -fn audit_query(user: &str, audit: &mut Vec, query: QueryType) { - audit.push(QueryAudit { - timestamp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(), - user: user.to_string(), - query, - }); -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::db_pool::user_db::UserDb; - use agdb::QueryBuilder; - - #[tokio::test] - #[should_panic] - async fn unreachable() { - let db = UserDb::new("memory:test").unwrap(); - db.get() - .await - .transaction(|t| { - t_exec( - t, - &mut QueryType::Remove(QueryBuilder::remove().ids(1).query()), - &[], - ) - }) - .unwrap(); - } -} diff --git a/agdb_server/src/db_pool/server_db.rs b/agdb_server/src/db_pool/server_db.rs new file mode 100644 index 00000000..8203490a --- /dev/null +++ b/agdb_server/src/db_pool/server_db.rs @@ -0,0 +1,24 @@ +use crate::db_pool::user_db_storage::UserDbStorage; +use crate::server_error::ServerResult; +use agdb::DbImpl; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio::sync::RwLockReadGuard; +use tokio::sync::RwLockWriteGuard; + +pub(crate) type ServerDbImpl = DbImpl; +pub(crate) struct ServerDb(pub(crate) Arc>); + +impl ServerDb { + pub(crate) fn new(name: &str) -> ServerResult { + Ok(Self(Arc::new(RwLock::new(ServerDbImpl::new(name)?)))) + } + + pub(crate) async fn get(&self) -> RwLockReadGuard { + self.0.read().await + } + + pub(crate) async fn get_mut(&self) -> RwLockWriteGuard { + self.0.write().await + } +} diff --git a/agdb_server/src/db_pool/user_db.rs b/agdb_server/src/db_pool/user_db.rs index 554fc6c6..9870758a 100644 --- a/agdb_server/src/db_pool/user_db.rs +++ b/agdb_server/src/db_pool/user_db.rs @@ -1,28 +1,317 @@ use crate::db_pool::user_db_storage::UserDbStorage; +use crate::db_pool::ErrorCode; +use crate::db_pool::ServerError; use crate::server_error::ServerResult; use agdb::DbImpl; +use agdb::QueryConditionData; +use agdb::QueryError; +use agdb::QueryId; +use agdb::QueryIds; +use agdb::QueryResult; +use agdb::QueryType; +use agdb::SearchQuery; +use agdb::Transaction; +use agdb::TransactionMut; +use agdb_api::Queries; +use agdb_api::QueryAudit; use std::sync::Arc; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; use tokio::sync::RwLock; -use tokio::sync::RwLockReadGuard; -use tokio::sync::RwLockWriteGuard; -pub(crate) type ServerDbImpl = DbImpl; -pub(crate) struct UserDb(pub(crate) Arc>); +pub(crate) type UserDbImpl = DbImpl; +pub(crate) struct UserDb(pub(crate) Arc>); impl UserDb { pub(crate) fn new(name: &str) -> ServerResult { - Ok(Self(Arc::new(RwLock::new(ServerDbImpl::new(name)?)))) + Ok(Self(Arc::new(RwLock::new(UserDbImpl::new(name)?)))) + } + + pub(crate) async fn backup(&self, name: &str) -> ServerResult<()> { + self.0.read().await.backup(name)?; + Ok(()) } pub(crate) async fn copy(&self, name: &str) -> ServerResult { - Ok(Self(Arc::new(RwLock::new(self.get().await.copy(name)?)))) + Ok(Self(Arc::new(RwLock::new(self.0.read().await.copy(name)?)))) + } + + pub(crate) async fn exec(&self, mut queries: Queries) -> ServerResult> { + self.0.read().await.transaction(|t| { + let mut results = vec![]; + + for q in queries.0.iter_mut() { + let result = t_exec(t, q, &results)?; + results.push(result); + } + + Ok(results) + }) + } + + pub(crate) async fn exec_mut( + &self, + mut queries: Queries, + username: &str, + ) -> ServerResult<(Vec, Vec)> { + self.0.write().await.transaction_mut(|t| { + let mut audit = vec![]; + let mut results = vec![]; + let mut qs = vec![]; + std::mem::swap(&mut queries.0, &mut qs); + + for q in qs { + let result = t_exec_mut(t, q, &results, &mut audit, username)?; + results.push(result); + } + + Ok((results, audit)) + }) + } + + pub(crate) async fn optimize_storage(&self) -> ServerResult<()> { + self.0.write().await.optimize_storage()?; + Ok(()) + } + + pub(crate) async fn rename(&self, target_name: &str) -> ServerResult<()> { + self.0.write().await.rename(target_name)?; + Ok(()) + } + + pub(crate) async fn size(&self) -> u64 { + self.0.read().await.size() + } +} + +fn t_exec( + t: &Transaction, + q: &mut QueryType, + results: &[QueryResult], +) -> ServerResult { + match q { + QueryType::Search(q) => { + inject_results_search(q, results)?; + t.exec(&*q) + } + QueryType::SelectAliases(q) => { + inject_results(&mut q.0, results)?; + t.exec(&*q) + } + QueryType::SelectAllAliases(q) => t.exec(&*q), + QueryType::SelectEdgeCount(q) => t.exec(&*q), + QueryType::SelectIndexes(q) => t.exec(&*q), + QueryType::SelectKeys(q) => { + inject_results(&mut q.0, results)?; + t.exec(&*q) + } + QueryType::SelectKeyCount(q) => { + inject_results(&mut q.0, results)?; + t.exec(&*q) + } + QueryType::SelectNodeCount(q) => t.exec(&*q), + QueryType::SelectValues(q) => { + inject_results(&mut q.ids, results)?; + t.exec(&*q) + } + _ => unreachable!(), + } + .map_err(|e| ServerError::new(ErrorCode::QueryError.into(), &e.description)) +} + +fn audit_query(user: &str, audit: &mut Vec, query: QueryType) { + audit.push(QueryAudit { + timestamp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + user: user.to_string(), + query, + }); +} + +fn t_exec_mut( + t: &mut TransactionMut, + mut q: QueryType, + results: &[QueryResult], + audit: &mut Vec, + username: &str, +) -> ServerResult { + let mut do_audit = false; + + let r = match &mut q { + QueryType::Search(q) => { + inject_results_search(q, results)?; + t.exec(&*q) + } + QueryType::SelectAliases(q) => { + inject_results(&mut q.0, results)?; + t.exec(&*q) + } + QueryType::SelectAllAliases(q) => t.exec(&*q), + QueryType::SelectEdgeCount(q) => t.exec(&*q), + QueryType::SelectIndexes(q) => t.exec(&*q), + QueryType::SelectKeys(q) => { + inject_results(&mut q.0, results)?; + t.exec(&*q) + } + QueryType::SelectKeyCount(q) => { + inject_results(&mut q.0, results)?; + t.exec(&*q) + } + QueryType::SelectNodeCount(q) => t.exec(&*q), + QueryType::SelectValues(q) => { + inject_results(&mut q.ids, results)?; + t.exec(&*q) + } + QueryType::InsertAlias(q) => { + do_audit = true; + inject_results(&mut q.ids, results)?; + t.exec_mut(&*q) + } + QueryType::InsertEdges(q) => { + do_audit = true; + inject_results(&mut q.ids, results)?; + inject_results(&mut q.from, results)?; + inject_results(&mut q.to, results)?; + + t.exec_mut(&*q) + } + QueryType::InsertNodes(q) => { + do_audit = true; + inject_results(&mut q.ids, results)?; + t.exec_mut(&*q) + } + QueryType::InsertValues(q) => { + do_audit = true; + inject_results(&mut q.ids, results)?; + t.exec_mut(&*q) + } + QueryType::Remove(q) => { + do_audit = true; + inject_results(&mut q.0, results)?; + t.exec_mut(&*q) + } + QueryType::InsertIndex(q) => { + do_audit = true; + t.exec_mut(&*q) + } + QueryType::RemoveAliases(q) => { + do_audit = true; + t.exec_mut(&*q) + } + QueryType::RemoveIndex(q) => { + do_audit = true; + t.exec_mut(&*q) + } + QueryType::RemoveValues(q) => { + do_audit = true; + inject_results(&mut q.0.ids, results)?; + t.exec_mut(&*q) + } + }; + + if do_audit { + audit_query(username, audit, q); + } + + r.map_err(|e| ServerError::new(ErrorCode::QueryError.into(), &e.description)) +} + +fn id_or_result(id: QueryId, results: &[QueryResult]) -> ServerResult { + if let QueryId::Alias(alias) = &id { + if let Some(index) = alias.strip_prefix(':') { + if let Ok(index) = index.parse::() { + return Ok(QueryId::Id( + results + .get(index) + .ok_or(ServerError::new( + ErrorCode::QueryError.into(), + &format!( + "Results index out of bounds '{index}' (> {})", + results.len() + ), + ))? + .elements + .first() + .ok_or(ServerError::new( + ErrorCode::QueryError.into(), + "No element found in the result", + ))? + .id, + )); + } + } + } + + Ok(id) +} + +fn inject_results(ids: &mut QueryIds, results: &[QueryResult]) -> ServerResult<()> { + match ids { + QueryIds::Ids(ids) => inject_results_ids(ids, results), + QueryIds::Search(search) => inject_results_search(search, results), } + .map_err(|mut e| { + e.status = ErrorCode::QueryError.into(); + e + }) +} - pub(crate) async fn get(&self) -> RwLockReadGuard { - self.0.read().await +fn inject_results_search(search: &mut SearchQuery, results: &[QueryResult]) -> ServerResult<()> { + search.origin = id_or_result(search.origin.clone(), results)?; + search.destination = id_or_result(search.destination.clone(), results)?; + + for c in &mut search.conditions { + if let QueryConditionData::Ids(ids) = &mut c.data { + inject_results_ids(ids, results)?; + } } - pub(crate) async fn get_mut(&self) -> RwLockWriteGuard { - self.0.write().await + Ok(()) +} + +fn inject_results_ids(ids: &mut Vec, results: &[QueryResult]) -> ServerResult<()> { + for i in 0..ids.len() { + if let QueryId::Alias(alias) = &ids[i] { + if let Some(index) = alias.strip_prefix(':') { + if let Ok(index) = index.parse::() { + let result_ids = results + .get(index) + .ok_or(ServerError::new( + ErrorCode::QueryError.into(), + &format!( + "Results index out of bounds '{index}' (> {})", + results.len() + ), + ))? + .ids() + .into_iter() + .map(QueryId::Id) + .collect::>(); + ids.splice(i..i + 1, result_ids.into_iter()); + } + } + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::db_pool::user_db::UserDb; + use agdb::QueryBuilder; + + #[tokio::test] + #[should_panic] + async fn unreachable() { + let db = UserDb::new("memory:test").unwrap(); + db.exec(Queries(vec![QueryType::Remove( + QueryBuilder::remove().ids(1).query(), + )])) + .await + .unwrap(); } } From e2180c355be45630040d3bb01d17b530bcabf65c Mon Sep 17 00:00:00 2001 From: Michael Vlach Date: Sun, 10 Nov 2024 13:26:38 +0100 Subject: [PATCH 3/3] Update user_db.rs --- agdb_server/src/db_pool/user_db.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/agdb_server/src/db_pool/user_db.rs b/agdb_server/src/db_pool/user_db.rs index 9870758a..071d7847 100644 --- a/agdb_server/src/db_pool/user_db.rs +++ b/agdb_server/src/db_pool/user_db.rs @@ -4,7 +4,6 @@ use crate::db_pool::ServerError; use crate::server_error::ServerResult; use agdb::DbImpl; use agdb::QueryConditionData; -use agdb::QueryError; use agdb::QueryId; use agdb::QueryIds; use agdb::QueryResult;