From c3a1590cf570b9a1c5512bca12634e7d72841d23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 23 Jan 2025 11:24:08 +0100 Subject: [PATCH 1/2] service: Move aux files when moving repos --- lib/src/lib.rs | 2 +- lib/src/repository/mod.rs | 14 +- service/src/state.rs | 118 ++++--------- service/src/state/move_repository.rs | 244 +++++++++++++++++++++++++++ service/src/state/tests.rs | 19 ++- 5 files changed, 296 insertions(+), 101 deletions(-) create mode 100644 service/src/state/move_repository.rs diff --git a/lib/src/lib.rs b/lib/src/lib.rs index ae999f3c..c5c972b8 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -61,7 +61,7 @@ pub use self::{ progress::Progress, protocol::{RepositoryId, StorageSize, BLOCK_SIZE}, repository::{ - repository_files, Credentials, Metadata, Repository, RepositoryHandle, RepositoryParams, + database_files, Credentials, Metadata, Repository, RepositoryHandle, RepositoryParams, }, store::{Error as StoreError, DATA_VERSION}, version_vector::VersionVector, diff --git a/lib/src/repository/mod.rs b/lib/src/repository/mod.rs index 5e44b2bc..905ec2a8 100644 --- a/lib/src/repository/mod.rs +++ b/lib/src/repository/mod.rs @@ -64,9 +64,17 @@ pub struct Repository { progress_reporter_handle: BlockingMutex>>, } -/// Given a path to the main repository file, return paths to all the repository files (the main db -/// file and all auxiliary files). -pub fn repository_files(store_path: impl AsRef) -> Vec { +/// List of repository database files. Includes the main db file and any auxiliary files. +/// +/// The aux files don't always exists but when they do and one wants to rename, move or delete the +/// repository, they should rename/move/delete these files as well. +/// +/// Note ideally the aux files should be deleted automatically when the repo has been closed. But +/// because of a [bug][1] in sqlx, they are sometimes not and need to be handled +/// (move,deleted) manually. When the bug is fixed, this function can be removed. +/// +/// [1]: https://github.com/launchbadge/sqlx/issues/3217 +pub fn database_files(store_path: impl AsRef) -> Vec { // Sqlite database consists of up to three files: main db (always present), WAL and WAL-index. ["", "-wal", "-shm"] .into_iter() diff --git a/service/src/state.rs b/service/src/state.rs index 28b12249..00a718d4 100644 --- a/service/src/state.rs +++ b/service/src/state.rs @@ -1,3 +1,7 @@ +mod move_repository; +#[cfg(test)] +mod tests; + use crate::{ config_keys::{ BIND_KEY, DEFAULT_BLOCK_EXPIRATION_MILLIS, DEFAULT_QUOTA_KEY, @@ -39,9 +43,6 @@ use tokio::{ use tokio_rustls::rustls; use tokio_stream::StreamExt; -#[cfg(test)] -mod tests; - const AUTOMOUNT_KEY: &str = "automount"; const DHT_ENABLED_KEY: &str = "dht_enabled"; const PEX_ENABLED_KEY: &str = "pex_enabled"; @@ -51,7 +52,7 @@ const REPOSITORY_FILE_EXTENSION: &str = "ouisyncdb"; pub(crate) struct State { pub config: ConfigStore, pub network: Network, - store_dir: Option, + store: Store, mounter: Option, repos: RepositorySet, files: FileSet, @@ -79,6 +80,8 @@ impl State { Err(error) => return Err(error.into()), }; + let store = Store { dir: store_dir }; + let mount_dir = match config.entry(MOUNT_DIR_KEY).get().await { Ok(dir) => Some(dir), Err(ConfigError::NotFound) => None, @@ -96,7 +99,7 @@ impl State { let mut state = Self { config, network, - store_dir, + store, mounter, root_monitor, repos_monitor, @@ -224,16 +227,16 @@ impl State { } pub fn store_dir(&self) -> Option<&Path> { - self.store_dir.as_deref() + self.store.dir.as_deref() } pub async fn set_store_dir(&mut self, dir: PathBuf) -> Result<(), Error> { - if Some(dir.as_path()) == self.store_dir.as_deref() { + if Some(dir.as_path()) == self.store.dir.as_deref() { return Ok(()); } self.config.entry(STORE_DIR_KEY).set(&dir).await?; - self.store_dir = Some(dir); + self.store.dir = Some(dir); // Close repos from the previous store dir and load repos from the new dir. self.close_repositories().await; @@ -300,7 +303,7 @@ impl State { dht_enabled: bool, pex_enabled: bool, ) -> Result { - let path = self.normalize_repository_path(path)?; + let path = self.store.normalize_repository_path(path)?; if self.repos.find_by_path(&path).is_some() { Err(Error::AlreadyExists)?; @@ -377,7 +380,7 @@ impl State { holder.close().await?; - for path in ouisync::repository_files(holder.path()) { + for path in ouisync::database_files(holder.path()) { match fs::remove_file(path).await { Ok(()) => (), Err(error) if error.kind() == io::ErrorKind::NotFound => (), @@ -385,7 +388,7 @@ impl State { } } - self.remove_empty_ancestor_dirs(holder.path()).await?; + self.store.remove_empty_ancestor_dirs(holder.path()).await?; tracing::info!(name = holder.short_name(), "repository deleted"); @@ -397,7 +400,7 @@ impl State { path: &Path, local_secret: Option, ) -> Result { - let path = self.normalize_repository_path(path)?; + let path = self.store.normalize_repository_path(path)?; let handle = if let Some((handle, holder)) = self.repos.find_by_path(&path) { // If `local_secret` provides higher access mode than what the repo currently has, // increase it. If not, the access mode remains unchanged. @@ -494,57 +497,7 @@ impl State { handle: RepositoryHandle, dst: &Path, ) -> Result<(), Error> { - // This function is "best effort atomic". - - let dst = self.normalize_repository_path(dst)?; - let dst_parent = dst.parent().ok_or(Error::InvalidArgument)?; - - if self.repos.find_by_path(&dst).is_some() { - return Err(Error::AlreadyExists); - } - - let holder = self.repos.get_mut(handle).ok_or(Error::InvalidArgument)?; - // Preserve access mode after the move - let credentials = holder.repository().credentials(); - let sync_enabled = holder.registration().is_some(); - - // TODO: close all open files of this repo - - if let Some(mounter) = &self.mounter { - mounter.remove(holder.short_name())?; - } - - fs::create_dir_all(dst_parent).await?; - - if let Err(error) = holder.close().await { - // Try to revert creating the dst parent directory but if it fails still return the - // close error because it's more important. - self.remove_empty_ancestor_dirs(dst_parent).await.ok(); - return Err(error); - } - - let src = holder.path().to_owned(); - - let (old_path, new_path, result) = match move_file(&src, &dst).await { - Ok(()) => (src, dst, Ok(())), - Err(error) => (dst, src, Err(error.into())), // Restore the original repo - }; - - *holder = load_repository( - &new_path, - None, - sync_enabled, - &self.config, - &self.network, - &self.repos_monitor, - self.mounter.as_ref(), - ) - .await?; - holder.repository().set_credentials(credentials).await?; - - self.remove_empty_ancestor_dirs(&old_path).await?; - - result + move_repository::invoke(self, handle, dst).await } pub async fn reset_repository_access( @@ -1386,7 +1339,7 @@ impl State { // Find all repositories in the store dir and open them. async fn load_repositories(&mut self) { - let Some(store_dir) = self.store_dir.as_deref() else { + let Some(store_dir) = self.store.dir.as_deref() else { tracing::warn!("store dir not specified"); return; }; @@ -1472,12 +1425,22 @@ impl State { .await } + async fn connect_remote_client(&self, host: &str) -> Result { + Ok(RemoteClient::connect(host, self.remote_client_config().await?).await?) + } +} + +struct Store { + dir: Option, +} + +impl Store { fn normalize_repository_path(&self, path: &Path) -> Result { let path = if path.is_absolute() { Cow::Borrowed(path) } else { Cow::Owned( - self.store_dir + self.dir .as_deref() .ok_or(Error::StoreDirUnspecified)? .join(path), @@ -1499,7 +1462,7 @@ impl State { // Remove ancestors directories up to `store_dir` but only if they are empty. async fn remove_empty_ancestor_dirs(&self, path: &Path) -> Result<(), io::Error> { - let Some(store_dir) = &self.store_dir else { + let Some(store_dir) = &self.dir else { return Ok(()); }; @@ -1521,10 +1484,6 @@ impl State { Ok(()) } - - async fn connect_remote_client(&self, host: &str) -> Result { - Ok(RemoteClient::connect(host, self.remote_client_config().await?).await?) - } } async fn load_repository( @@ -1596,25 +1555,6 @@ async fn set_metadata_bool(repo: &Repository, key: &str, value: bool) -> Result< Ok(()) } -/// Move file from `src` to `dst`. If they are on the same filesystem, it does a simple rename. -/// Otherwise it copies `src` to `dst` first and then deletes `src`. -async fn move_file(src: &Path, dst: &Path) -> io::Result<()> { - // First try rename - match fs::rename(src, dst).await { - Ok(()) => return Ok(()), - Err(_error) => { - // TODO: we should only fallback on `io::ErrorKind::CrossesDevices` but that variant is - // currently unstable. - } - } - - // If that didn't work, fallback to copy + remove - fs::copy(src, dst).await?; - fs::remove_file(src).await?; - - Ok(()) -} - async fn set_default_quota( config: &ConfigStore, value: Option, diff --git a/service/src/state/move_repository.rs b/service/src/state/move_repository.rs new file mode 100644 index 00000000..ab5329cc --- /dev/null +++ b/service/src/state/move_repository.rs @@ -0,0 +1,244 @@ +use std::{ + fmt, io, + path::{Path, PathBuf}, +}; + +use crate::{protocol::RepositoryHandle, Error}; +use ouisync::{Credentials, Network}; +use ouisync_vfs::{MultiRepoMount, MultiRepoVFS}; +use state_monitor::StateMonitor; +use tokio::fs; + +use super::{load_repository, ConfigStore, RepositoryHolder, State, Store}; + +/// Move or rename a repository. Makes "best effort" to do it atomically, that is, if any step of +/// this operation fails, tries to revert all previous steps before returning. +pub(super) async fn invoke( + state: &mut State, + handle: RepositoryHandle, + dst: &Path, +) -> Result<(), Error> { + let mut context = Context::new(state, handle, dst)?; + let mut undo_stack = Vec::new(); + + match context.invoke(&mut undo_stack).await { + Ok(()) => Ok(()), + Err(error) => { + context.undo(&mut undo_stack).await; + Err(error) + } + } +} + +struct Context<'a> { + config: &'a ConfigStore, + network: &'a Network, + store: &'a Store, + mounter: Option<&'a MultiRepoVFS>, + repos_monitor: &'a StateMonitor, + holder: &'a mut RepositoryHolder, + dst: PathBuf, +} + +impl<'a> Context<'a> { + fn new(state: &'a mut State, handle: RepositoryHandle, dst: &Path) -> Result { + let dst = state.store.normalize_repository_path(dst)?; + + if state.repos.find_by_path(&dst).is_some() { + return Err(Error::AlreadyExists); + } + + let holder = state.repos.get_mut(handle).ok_or(Error::InvalidArgument)?; + + Ok(Self { + config: &state.config, + network: &state.network, + store: &state.store, + mounter: state.mounter.as_ref(), + repos_monitor: &state.repos_monitor, + holder, + dst, + }) + } + + async fn invoke(&mut self, undo_stack: &mut Vec) -> Result<(), Error> { + // TODO: close all open files of this repo + + // 1. Unmount the repo (if mounted) + if let Some(mounter) = self.mounter { + mounter.remove(self.holder.short_name())?; + undo_stack.push(Action::Unmount); + } + + // 2. Create the dst directory + let dst_parent = self.dst.parent().ok_or(Error::InvalidArgument)?; + fs::create_dir_all(dst_parent).await?; + undo_stack.push(Action::CreateDir { + path: dst_parent.to_owned(), + }); + + // 3. Close the repo + let credentials = self.holder.repository().credentials(); + let sync_enabled = self.holder.registration().is_some(); + + self.holder.close().await?; + undo_stack.push(Action::CloseRepository { + credentials: credentials.clone(), + sync_enabled, + }); + + // 4. Move the database file(s) + for (src, dst) in ouisync::database_files(self.holder.path()) + .into_iter() + .zip(ouisync::database_files(&self.dst)) + { + if !fs::try_exists(&src).await? { + continue; + } + + move_file(&src, &dst).await?; + undo_stack.push(Action::MoveFile { src, dst }); + } + + // 5. Remove the old parent directory + let src_parent = self + .holder + .path() + .parent() + .ok_or(Error::InvalidArgument)? + .to_owned(); + self.store + .remove_empty_ancestor_dirs(self.holder.path()) + .await?; + undo_stack.push(Action::RemoveDir { path: src_parent }); + + // 6. Open the repository from its new location + *self.holder = self + .load_repository(&self.dst, credentials, sync_enabled) + .await?; + + Ok(()) + } + + async fn undo(&mut self, undo_stack: &mut Vec) { + while let Some(action) = undo_stack.pop() { + let action_debug = format!("{:?}", action); + action + .undo(self) + .await + .inspect_err(|error| tracing::error!(?error, "failed to undo {action_debug}")) + .ok(); + } + } + + async fn load_repository( + &self, + path: &Path, + credentials: Credentials, + sync_enabled: bool, + ) -> Result { + let holder = load_repository( + path, + None, + sync_enabled, + self.config, + self.network, + self.repos_monitor, + self.mounter, + ) + .await?; + holder.repository().set_credentials(credentials).await?; + + Ok(holder) + } +} + +#[expect(clippy::large_enum_variant)] +enum Action { + Unmount, + CreateDir { + path: PathBuf, + }, + CloseRepository { + credentials: Credentials, + sync_enabled: bool, + }, + MoveFile { + src: PathBuf, + dst: PathBuf, + }, + RemoveDir { + path: PathBuf, + }, +} + +impl Action { + async fn undo(self, context: &mut Context<'_>) -> Result<(), Error> { + match self { + Self::Unmount => { + if let Some(mounter) = context.mounter { + mounter.insert( + context.holder.short_name().to_owned(), + context.holder.repository().clone(), + )?; + } + } + Self::CreateDir { path } => { + context.store.remove_empty_ancestor_dirs(&path).await?; + } + Self::CloseRepository { + credentials, + sync_enabled, + } => { + *context.holder = context + .load_repository(context.holder.path(), credentials, sync_enabled) + .await?; + } + Self::MoveFile { src, dst } => { + move_file(&dst, &src).await?; + } + Self::RemoveDir { path } => { + fs::create_dir_all(path).await?; + } + } + + Ok(()) + } +} + +impl fmt::Debug for Action { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Unmount => f.debug_tuple("Unmount").finish(), + Self::CreateDir { path } => f.debug_struct("CreateDir").field("path", path).finish(), + Self::CloseRepository { .. } => { + f.debug_struct("CloseRepository").finish_non_exhaustive() + } + Self::MoveFile { src, dst } => f + .debug_struct("MoveFile") + .field("src", src) + .field("dst", dst) + .finish(), + Self::RemoveDir { path } => f.debug_struct("RemoveDir").field("path", path).finish(), + } + } +} + +/// Move file from `src` to `dst`. If they are on the same filesystem, it does a simple rename. +/// Otherwise it copies `src` to `dst` first and then deletes `src`. +async fn move_file(src: &Path, dst: &Path) -> io::Result<()> { + // First try rename + match fs::rename(src, dst).await { + Ok(()) => return Ok(()), + Err(_error) => { + // TODO: we should only fallback on `io::ErrorKind::CrossesDevices` but that variant is + // currently unstable. + } + } + + // If that didn't work, fallback to copy + remove + fs::copy(src, dst).await?; + fs::remove_file(src).await?; + + Ok(()) +} diff --git a/service/src/state/tests.rs b/service/src/state/tests.rs index c3ce2272..2910eb02 100644 --- a/service/src/state/tests.rs +++ b/service/src/state/tests.rs @@ -13,34 +13,35 @@ use tracing::Instrument; #[tokio::test] async fn normalize_repository_path() { - let (temp_dir, mut state) = setup().await; - let store_dir = temp_dir.path().join("store"); - state.set_store_dir(store_dir.clone()).await.unwrap(); + let store_dir = PathBuf::from("/home/alice/ouisync"); + let store = Store { + dir: Some(store_dir.clone()), + }; assert_eq!( - state.normalize_repository_path(Path::new("foo")).unwrap(), + store.normalize_repository_path(Path::new("foo")).unwrap(), store_dir.join("foo.ouisyncdb") ); assert_eq!( - state + store .normalize_repository_path(Path::new("foo/bar")) .unwrap(), store_dir.join("foo/bar.ouisyncdb") ); assert_eq!( - state + store .normalize_repository_path(Path::new("foo/bar.baz")) .unwrap(), store_dir.join("foo/bar.baz.ouisyncdb") ); assert_eq!( - state + store .normalize_repository_path(Path::new("foo.ouisyncdb")) .unwrap(), store_dir.join("foo.ouisyncdb") ); assert_eq!( - state + store .normalize_repository_path(Path::new("/home/alice/repos/foo")) .unwrap(), Path::new("/home/alice/repos/foo.ouisyncdb") @@ -295,6 +296,8 @@ async fn expire_synced_repository() { #[tokio::test] async fn move_repository() { + test_utils::init_log(); + let (_temp_dir, mut state) = setup().await; let src = Path::new("foo"); let dst = Path::new("bar"); From d0a78ccd89fea42dfa2a9581a8c48c0d88bb63a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 23 Jan 2025 15:47:42 +0100 Subject: [PATCH 2/2] service: change log callback signature for better portability --- bindings/dart/lib/bindings.dart | 23 ++------ bindings/dart/lib/server.dart | 8 +-- bindings/dart/test/ouisync_test.dart | 5 ++ .../org/equalitie/ouisync/lib/Bindings.kt | 8 ++- .../org/equalitie/ouisync/lib/Client.kt | 34 ++++++----- .../org/equalitie/ouisync/lib/Server.kt | 17 +++++- .../org/equalitie/ouisync/lib/Session.kt | 10 +++- .../org/equalitie/ouisync/SessionTest.kt | 4 +- service/src/ffi.rs | 57 ++++++------------- 9 files changed, 76 insertions(+), 90 deletions(-) diff --git a/bindings/dart/lib/bindings.dart b/bindings/dart/lib/bindings.dart index 84ed79bb..96e16498 100644 --- a/bindings/dart/lib/bindings.dart +++ b/bindings/dart/lib/bindings.dart @@ -6,24 +6,11 @@ import 'package:path/path.dart'; export 'bindings.g.dart'; -final class LogMessage extends Struct { - @Uint8() - external int level; - - external Pointer ptr; - - @Uint64() - external int len; - - @Uint64() - external int cap; -} - -/// Callback for `service_start` and `service_stop`. +/// Callback for `start_service` and `stop_service`. typedef StatusCallback = Void Function(Pointer, Uint16); -/// Callback for `log_init`. -typedef LogCallback = Void Function(LogMessage); +/// Callback for `init_log`. +typedef LogCallback = Void Function(Uint8, Pointer, Uint64, Uint64); /// typedef StartService = Pointer Function( @@ -56,8 +43,8 @@ typedef _InitLogC = Uint16 Function( Pointer>, ); -typedef ReleaseLogMessage = void Function(LogMessage); -typedef _ReleaseLogMessageC = Void Function(LogMessage); +typedef ReleaseLogMessage = void Function(Pointer, int, int); +typedef _ReleaseLogMessageC = Void Function(Pointer, Uint64, Uint64); class Bindings { Bindings(DynamicLibrary library) diff --git a/bindings/dart/lib/server.dart b/bindings/dart/lib/server.dart index 38eb1b05..f9dbd6f6 100644 --- a/bindings/dart/lib/server.dart +++ b/bindings/dart/lib/server.dart @@ -98,13 +98,13 @@ void initLog({ if (callback != null) { nativeCallback = NativeCallable.listener( - (LogMessage message) { + (int level, Pointer ptr, int len, int cap) { callback( - LogLevel.decode(message.level), - utf8.decode(message.ptr.asTypedList(message.len)), + LogLevel.decode(level), + utf8.decode(ptr.asTypedList(len)), ); - Bindings.instance.releaseLogMessage(message); + Bindings.instance.releaseLogMessage(ptr, len, cap); }, ); } diff --git a/bindings/dart/test/ouisync_test.dart b/bindings/dart/test/ouisync_test.dart index 0a356de5..e7c94894 100644 --- a/bindings/dart/test/ouisync_test.dart +++ b/bindings/dart/test/ouisync_test.dart @@ -8,6 +8,11 @@ void main() { late io.Directory temp; late Session session; + initLog(callback: (level, message) { + // ignore: avoid_print + print('${level.name.toUpperCase()} $message'); + }); + setUp(() async { temp = await io.Directory.systemTemp.createTemp(); session = await Session.create( diff --git a/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Bindings.kt b/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Bindings.kt index bcad7d6d..c0ba88ea 100644 --- a/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Bindings.kt +++ b/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Bindings.kt @@ -30,14 +30,16 @@ internal interface Bindings : Library { file: String?, callback: LogCallback?, ): Short + + fun release_log_message(ptr: Pointer, len: Long, cap: Long) } internal typealias Handle = Long -interface StatusCallback : JnaCallback { +internal interface StatusCallback : JnaCallback { fun invoke(context: Pointer?, error_code: Short) } -interface LogCallback : JnaCallback { - fun invoke(level: Byte, message: String) +internal interface LogCallback : JnaCallback { + fun invoke(level: Byte, ptr: Pointer, len: Long, cap: Long) } diff --git a/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Client.kt b/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Client.kt index 335a629c..af376cf7 100644 --- a/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Client.kt +++ b/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Client.kt @@ -64,15 +64,27 @@ internal class Client private constructor(private val socket: AsynchronousSocket } } - suspend fun invoke(request: Request): Any = - invoke(messageMatcher.nextId(), request) + suspend fun invoke(request: Request): Any { + val id = messageMatcher.nextId() + val deferred = CompletableDeferred() + messageMatcher.register(id, deferred) + + send(id, request) + + val response = deferred.await() + + when (response) { + is Success -> return response.value + is Failure -> throw response.error + } + } fun subscribe(request: Request): Flow = channelFlow { val id = messageMatcher.nextId() + messageMatcher.register(id, channel) try { - invoke(id, request) - messageMatcher.register(id, channel) + send(id, request) awaitClose() } finally { messageMatcher.deregister(id) @@ -92,20 +104,6 @@ internal class Client private constructor(private val socket: AsynchronousSocket socket.close() } - private suspend fun invoke(id: Long, request: Request): Any { - val deferred = CompletableDeferred() - messageMatcher.register(id, deferred) - - send(id, request) - - val response = deferred.await() - - when (response) { - is Success -> return response.value - is Failure -> throw response.error - } - } - private suspend fun send(id: Long, request: Request) { // Message format: // diff --git a/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Server.kt b/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Server.kt index 1a1fcee0..9e199f73 100644 --- a/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Server.kt +++ b/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Server.kt @@ -29,11 +29,15 @@ class Server private constructor(private val handle: Pointer) { typealias LogFunction = (level: LogLevel, message: String) -> Unit +// Need to keep the callback referenced to prevent it from being GC'd. +private var logHandler: LogHandler? = null + fun initLog( file: String? = null, callback: LogFunction? = null, ) { - bindings.init_log(file, callback?.let(::LogHandler)) + logHandler = logHandler ?: callback?.let(::LogHandler) + bindings.init_log(file, logHandler) } private class ResultHandler() : StatusCallback { @@ -53,7 +57,14 @@ private class ResultHandler() : StatusCallback { } private class LogHandler(val function: LogFunction) : LogCallback { - override fun invoke(level: Byte, message: String) { - function(LogLevel.decode(level), message) + override fun invoke(level: Byte, ptr: Pointer, len: Long, cap: Long) { + val level = LogLevel.decode(level) + val message = ptr.getByteArray(0, len.toInt()).decodeToString() + + try { + function(level, message) + } finally { + bindings.release_log_message(ptr, len, cap) + } } } diff --git a/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Session.kt b/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Session.kt index 79a499d1..848d9aaf 100644 --- a/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Session.kt +++ b/bindings/kotlin/lib/src/main/kotlin/org/equalitie/ouisync/lib/Session.kt @@ -2,6 +2,7 @@ package org.equalitie.ouisync.lib import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.mapNotNull /** * The entry point to the ouisync library. @@ -115,7 +116,14 @@ class Session private constructor( * Note the event subscription is created only after the flow starts being consumed. **/ fun subscribeToNetworkEvents(): Flow = - client.subscribe(NetworkSubscribe()).filterIsInstance() + client.subscribe(NetworkSubscribe()) + .mapNotNull { + when (it) { + is NetworkEvent -> it + is Any -> NetworkEvent.PEER_SET_CHANGE + else -> null + } + } /** * Returns the listener interface addresses. diff --git a/bindings/kotlin/lib/src/test/kotlin/org/equalitie/ouisync/SessionTest.kt b/bindings/kotlin/lib/src/test/kotlin/org/equalitie/ouisync/SessionTest.kt index 7150e845..f640d421 100644 --- a/bindings/kotlin/lib/src/test/kotlin/org/equalitie/ouisync/SessionTest.kt +++ b/bindings/kotlin/lib/src/test/kotlin/org/equalitie/ouisync/SessionTest.kt @@ -2,7 +2,6 @@ package org.equalitie.ouisync.lib import kotlinx.coroutines.channels.produce import kotlinx.coroutines.test.runTest -import kotlinx.coroutines.yield import org.junit.After import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse @@ -63,7 +62,8 @@ class SessionTest { session.subscribeToNetworkEvents().collect(::send) } - yield() + // Wait for the initial event indicating that the subscription has been created + assertEquals(NetworkEvent.PEER_SET_CHANGE, events.receive()) session.addUserProvidedPeers(listOf(addr)) assertEquals(NetworkEvent.PEER_SET_CHANGE, events.receive()) diff --git a/service/src/ffi.rs b/service/src/ffi.rs index 2162c2bc..09b69018 100644 --- a/service/src/ffi.rs +++ b/service/src/ffi.rs @@ -1,5 +1,5 @@ use std::{ - ffi::{c_char, c_void, CStr, CString}, + ffi::{c_char, c_uchar, c_ulong, c_void, CStr, CString}, io, mem, path::Path, pin::pin, @@ -166,6 +166,8 @@ fn init( Ok((runtime, service, span)) } +pub type LogCallback = extern "C" fn(LogLevel, *const c_uchar, c_ulong, c_ulong); + /// Initialize logging. Should be called before `service_start`. /// /// If `file` is not null, write log messages to the given file. @@ -178,10 +180,7 @@ fn init( /// /// `file` must be either null or it must be safe to pass to [std::ffi::CStr::from_ptr]. #[no_mangle] -pub unsafe extern "C" fn init_log( - file: *const c_char, - callback: Option, -) -> ErrorCode { +pub unsafe extern "C" fn init_log(file: *const c_char, callback: Option) -> ErrorCode { try_init_log(file, callback).to_error_code() } @@ -189,44 +188,17 @@ pub unsafe extern "C" fn init_log( /// /// # Safety /// -/// `message` must have been obtained through the callback to `init_log` and not modified. +/// `ptr`, `len` and `cap` must have been obtained through the callback to `init_log` and not +/// modified. #[no_mangle] -pub unsafe extern "C" fn release_log_message(message: LogMessage) { - let message = message.into_message(); +pub unsafe extern "C" fn release_log_message(ptr: *const c_uchar, len: c_ulong, cap: c_ulong) { + let message = Vec::from_raw_parts(ptr as _, len as _, cap as _); if let Some(pool) = LOGGER.get().and_then(|wrapper| wrapper.pool.as_ref()) { pool.release(message); } } -#[repr(C)] -pub struct LogMessage { - level: LogLevel, - ptr: *const u8, - len: usize, - cap: usize, -} - -impl LogMessage { - fn new(level: LogLevel, message: Vec) -> Self { - let ptr = message.as_ptr(); - let len = message.len(); - let cap = message.capacity(); - mem::forget(message); - - Self { - level, - ptr, - len, - cap, - } - } - - unsafe fn into_message(self) -> Vec { - Vec::from_raw_parts(self.ptr as _, self.len, self.cap) - } -} - struct LoggerWrapper { _logger: Logger, pool: Option, @@ -234,10 +206,7 @@ struct LoggerWrapper { static LOGGER: OnceLock = OnceLock::new(); -unsafe fn try_init_log( - file: *const c_char, - callback: Option, -) -> Result<(), Error> { +unsafe fn try_init_log(file: *const c_char, callback: Option) -> Result<(), Error> { let builder = Logger::builder(); let builder = if !file.is_null() { builder.file(Path::new(CStr::from_ptr(file).to_str()?)) @@ -248,7 +217,13 @@ unsafe fn try_init_log( let (builder, pool) = if let Some(callback) = callback { let pool = BufferPool::default(); let callback = Box::new(move |level, message: &mut Vec| { - callback(LogMessage::new(LogLevel::from(level), mem::take(message))); + let message = mem::take(message); + let ptr = message.as_ptr(); + let len = message.len(); + let cap = message.capacity(); + mem::forget(message); + + callback(LogLevel::from(level), ptr, len as _, cap as _); }); (builder.callback(callback, pool.clone()), Some(pool))