Skip to content

Commit

Permalink
Merge branch 'move-repo-fix' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
madadam committed Jan 23, 2025
2 parents 187ea25 + d0a78cc commit ba37d86
Show file tree
Hide file tree
Showing 14 changed files with 372 additions and 191 deletions.
23 changes: 5 additions & 18 deletions bindings/dart/lib/bindings.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,11 @@ import 'package:path/path.dart';

export 'bindings.g.dart';

final class LogMessage extends Struct {
@Uint8()
external int level;

external Pointer<Uint8> ptr;

@Uint64()
external int len;

@Uint64()
external int cap;
}

/// Callback for `service_start` and `service_stop`.
/// Callback for `start_service` and `stop_service`.
typedef StatusCallback = Void Function(Pointer<Void>, Uint16);

/// Callback for `log_init`.
typedef LogCallback = Void Function(LogMessage);
/// Callback for `init_log`.
typedef LogCallback = Void Function(Uint8, Pointer<Uint8>, Uint64, Uint64);

///
typedef StartService = Pointer<Void> Function(
Expand Down Expand Up @@ -56,8 +43,8 @@ typedef _InitLogC = Uint16 Function(
Pointer<NativeFunction<LogCallback>>,
);

typedef ReleaseLogMessage = void Function(LogMessage);
typedef _ReleaseLogMessageC = Void Function(LogMessage);
typedef ReleaseLogMessage = void Function(Pointer<Uint8>, int, int);
typedef _ReleaseLogMessageC = Void Function(Pointer<Uint8>, Uint64, Uint64);

class Bindings {
Bindings(DynamicLibrary library)
Expand Down
8 changes: 4 additions & 4 deletions bindings/dart/lib/server.dart
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ void initLog({

if (callback != null) {
nativeCallback = NativeCallable<LogCallback>.listener(
(LogMessage message) {
(int level, Pointer<Uint8> ptr, int len, int cap) {
callback(
LogLevel.decode(message.level),
utf8.decode(message.ptr.asTypedList(message.len)),
LogLevel.decode(level),
utf8.decode(ptr.asTypedList(len)),
);

Bindings.instance.releaseLogMessage(message);
Bindings.instance.releaseLogMessage(ptr, len, cap);
},
);
}
Expand Down
5 changes: 5 additions & 0 deletions bindings/dart/test/ouisync_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ void main() {
late io.Directory temp;
late Session session;

initLog(callback: (level, message) {
// ignore: avoid_print
print('${level.name.toUpperCase()} $message');
});

setUp(() async {
temp = await io.Directory.systemTemp.createTemp();
session = await Session.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ internal interface Bindings : Library {
file: String?,
callback: LogCallback?,
): Short

fun release_log_message(ptr: Pointer, len: Long, cap: Long)
}

internal typealias Handle = Long

interface StatusCallback : JnaCallback {
internal interface StatusCallback : JnaCallback {
fun invoke(context: Pointer?, error_code: Short)
}

interface LogCallback : JnaCallback {
fun invoke(level: Byte, message: String)
internal interface LogCallback : JnaCallback {
fun invoke(level: Byte, ptr: Pointer, len: Long, cap: Long)
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,27 @@ internal class Client private constructor(private val socket: AsynchronousSocket
}
}

suspend fun invoke(request: Request): Any =
invoke(messageMatcher.nextId(), request)
suspend fun invoke(request: Request): Any {
val id = messageMatcher.nextId()
val deferred = CompletableDeferred<Response>()
messageMatcher.register(id, deferred)

send(id, request)

val response = deferred.await()

when (response) {
is Success -> return response.value
is Failure -> throw response.error
}
}

fun subscribe(request: Request): Flow<Any> = channelFlow {
val id = messageMatcher.nextId()
messageMatcher.register(id, channel)

try {
invoke(id, request)
messageMatcher.register(id, channel)
send(id, request)
awaitClose()
} finally {
messageMatcher.deregister(id)
Expand All @@ -92,20 +104,6 @@ internal class Client private constructor(private val socket: AsynchronousSocket
socket.close()
}

private suspend fun invoke(id: Long, request: Request): Any {
val deferred = CompletableDeferred<Response>()
messageMatcher.register(id, deferred)

send(id, request)

val response = deferred.await()

when (response) {
is Success -> return response.value
is Failure -> throw response.error
}
}

private suspend fun send(id: Long, request: Request) {
// Message format:
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ class Server private constructor(private val handle: Pointer) {

typealias LogFunction = (level: LogLevel, message: String) -> Unit

// Need to keep the callback referenced to prevent it from being GC'd.
private var logHandler: LogHandler? = null

fun initLog(
file: String? = null,
callback: LogFunction? = null,
) {
bindings.init_log(file, callback?.let(::LogHandler))
logHandler = logHandler ?: callback?.let(::LogHandler)
bindings.init_log(file, logHandler)
}

private class ResultHandler() : StatusCallback {
Expand All @@ -53,7 +57,14 @@ private class ResultHandler() : StatusCallback {
}

private class LogHandler(val function: LogFunction) : LogCallback {
override fun invoke(level: Byte, message: String) {
function(LogLevel.decode(level), message)
override fun invoke(level: Byte, ptr: Pointer, len: Long, cap: Long) {
val level = LogLevel.decode(level)
val message = ptr.getByteArray(0, len.toInt()).decodeToString()

try {
function(level, message)
} finally {
bindings.release_log_message(ptr, len, cap)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.equalitie.ouisync.lib

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.mapNotNull

/**
* The entry point to the ouisync library.
Expand Down Expand Up @@ -115,7 +116,14 @@ class Session private constructor(
* Note the event subscription is created only after the flow starts being consumed.
**/
fun subscribeToNetworkEvents(): Flow<NetworkEvent> =
client.subscribe(NetworkSubscribe()).filterIsInstance<NetworkEvent>()
client.subscribe(NetworkSubscribe())
.mapNotNull {
when (it) {
is NetworkEvent -> it
is Any -> NetworkEvent.PEER_SET_CHANGE
else -> null
}
}

/**
* Returns the listener interface addresses.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org.equalitie.ouisync.lib

import kotlinx.coroutines.channels.produce
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.yield
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
Expand Down Expand Up @@ -63,7 +62,8 @@ class SessionTest {
session.subscribeToNetworkEvents().collect(::send)
}

yield()
// Wait for the initial event indicating that the subscription has been created
assertEquals(NetworkEvent.PEER_SET_CHANGE, events.receive())

session.addUserProvidedPeers(listOf(addr))
assertEquals(NetworkEvent.PEER_SET_CHANGE, events.receive())
Expand Down
2 changes: 1 addition & 1 deletion lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub use self::{
progress::Progress,
protocol::{RepositoryId, StorageSize, BLOCK_SIZE},
repository::{
repository_files, Credentials, Metadata, Repository, RepositoryHandle, RepositoryParams,
database_files, Credentials, Metadata, Repository, RepositoryHandle, RepositoryParams,
},
store::{Error as StoreError, DATA_VERSION},
version_vector::VersionVector,
Expand Down
14 changes: 11 additions & 3 deletions lib/src/repository/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,17 @@ pub struct Repository {
progress_reporter_handle: BlockingMutex<Option<ScopedJoinHandle<()>>>,
}

/// Given a path to the main repository file, return paths to all the repository files (the main db
/// file and all auxiliary files).
pub fn repository_files(store_path: impl AsRef<Path>) -> Vec<PathBuf> {
/// List of repository database files. Includes the main db file and any auxiliary files.
///
/// The aux files don't always exists but when they do and one wants to rename, move or delete the
/// repository, they should rename/move/delete these files as well.
///
/// Note ideally the aux files should be deleted automatically when the repo has been closed. But
/// because of a [bug][1] in sqlx, they are sometimes not and need to be handled
/// (move,deleted) manually. When the bug is fixed, this function can be removed.
///
/// [1]: https://github.com/launchbadge/sqlx/issues/3217
pub fn database_files(store_path: impl AsRef<Path>) -> Vec<PathBuf> {
// Sqlite database consists of up to three files: main db (always present), WAL and WAL-index.
["", "-wal", "-shm"]
.into_iter()
Expand Down
57 changes: 16 additions & 41 deletions service/src/ffi.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
ffi::{c_char, c_void, CStr, CString},
ffi::{c_char, c_uchar, c_ulong, c_void, CStr, CString},
io, mem,
path::Path,
pin::pin,
Expand Down Expand Up @@ -166,6 +166,8 @@ fn init(
Ok((runtime, service, span))
}

pub type LogCallback = extern "C" fn(LogLevel, *const c_uchar, c_ulong, c_ulong);

/// Initialize logging. Should be called before `service_start`.
///
/// If `file` is not null, write log messages to the given file.
Expand All @@ -178,66 +180,33 @@ fn init(
///
/// `file` must be either null or it must be safe to pass to [std::ffi::CStr::from_ptr].
#[no_mangle]
pub unsafe extern "C" fn init_log(
file: *const c_char,
callback: Option<extern "C" fn(LogMessage)>,
) -> ErrorCode {
pub unsafe extern "C" fn init_log(file: *const c_char, callback: Option<LogCallback>) -> ErrorCode {
try_init_log(file, callback).to_error_code()
}

/// Release a log message back to the backend. See `init_log` for more details.
///
/// # Safety
///
/// `message` must have been obtained through the callback to `init_log` and not modified.
/// `ptr`, `len` and `cap` must have been obtained through the callback to `init_log` and not
/// modified.
#[no_mangle]
pub unsafe extern "C" fn release_log_message(message: LogMessage) {
let message = message.into_message();
pub unsafe extern "C" fn release_log_message(ptr: *const c_uchar, len: c_ulong, cap: c_ulong) {
let message = Vec::from_raw_parts(ptr as _, len as _, cap as _);

if let Some(pool) = LOGGER.get().and_then(|wrapper| wrapper.pool.as_ref()) {
pool.release(message);
}
}

#[repr(C)]
pub struct LogMessage {
level: LogLevel,
ptr: *const u8,
len: usize,
cap: usize,
}

impl LogMessage {
fn new(level: LogLevel, message: Vec<u8>) -> Self {
let ptr = message.as_ptr();
let len = message.len();
let cap = message.capacity();
mem::forget(message);

Self {
level,
ptr,
len,
cap,
}
}

unsafe fn into_message(self) -> Vec<u8> {
Vec::from_raw_parts(self.ptr as _, self.len, self.cap)
}
}

struct LoggerWrapper {
_logger: Logger,
pool: Option<BufferPool>,
}

static LOGGER: OnceLock<LoggerWrapper> = OnceLock::new();

unsafe fn try_init_log(
file: *const c_char,
callback: Option<extern "C" fn(LogMessage)>,
) -> Result<(), Error> {
unsafe fn try_init_log(file: *const c_char, callback: Option<LogCallback>) -> Result<(), Error> {
let builder = Logger::builder();
let builder = if !file.is_null() {
builder.file(Path::new(CStr::from_ptr(file).to_str()?))
Expand All @@ -248,7 +217,13 @@ unsafe fn try_init_log(
let (builder, pool) = if let Some(callback) = callback {
let pool = BufferPool::default();
let callback = Box::new(move |level, message: &mut Vec<u8>| {
callback(LogMessage::new(LogLevel::from(level), mem::take(message)));
let message = mem::take(message);
let ptr = message.as_ptr();
let len = message.len();
let cap = message.capacity();
mem::forget(message);

callback(LogLevel::from(level), ptr, len as _, cap as _);
});

(builder.callback(callback, pool.clone()), Some(pool))
Expand Down
Loading

0 comments on commit ba37d86

Please sign in to comment.