Skip to content

Commit

Permalink
feat: add generalized totalizer to capi
Browse files Browse the repository at this point in the history
  • Loading branch information
chrjabs committed Sep 3, 2024
1 parent f519ee7 commit af19b19
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 0 deletions.
161 changes: 161 additions & 0 deletions capi/src/encodings/gte.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
//! # Generatlized Totalizer C-API
use std::ffi::{c_int, c_void};

use rustsat::{
encodings::pb::{BoundUpper, BoundUpperIncremental, DbGte},
types::Lit,
};

use super::{CAssumpCollector, CClauseCollector, ClauseCollector, MaybeError, VarManager};

/// Creates a new [`DbGte`] cardinality encoding
#[no_mangle]
pub extern "C" fn gte_new() -> *mut DbGte {
Box::into_raw(Box::default())
}

/// Adds a new input literal to a [`DbGte`].
///
/// # Errors
///
/// - If `lit` is not a valid IPASIR-style literal (e.g., `lit = 0`),
/// [`MaybeError::InvalidLiteral`] is returned
///
/// # Safety
///
/// `gte` must be a return value of [`gte_new`] that [`gte_drop`] has not yet been called on.
#[no_mangle]
pub unsafe extern "C" fn gte_add(gte: *mut DbGte, lit: c_int, weight: usize) -> MaybeError {
let Ok(lit) = Lit::from_ipasir(lit) else {
return MaybeError::InvalidLiteral;
};
unsafe { (*gte).extend([(lit, weight)]) };
MaybeError::Ok
}

/// Lazily builds the _change in_ pseudo-boolean encoding to enable upper bounds from within the
/// range.
///
/// The min and max bounds are inclusive. After a call to [`gte_encode_ub`] with `min_bound=2` and
/// `max_bound=4`, bounds satisfying `2 <= bound <= 4` can be enforced.
///
/// Clauses are returned via the `collector`. The `collector` function should expect clauses to be
/// passed similarly to `ipasir_add`, as a 0-terminated sequence of literals where the literals are
/// passed as the first argument and the `collector_data` as a second.
///
/// `n_vars_used` must be the number of variables already used and will be incremented by the
/// number of variables used up in the encoding.
///
/// # Safety
///
/// `gte` must be a return value of [`gte_new`] that [`gte_drop`] has not yet been called on.
#[no_mangle]
pub unsafe extern "C" fn gte_encode_ub(
gte: *mut DbGte,
min_bound: usize,
max_bound: usize,
n_vars_used: &mut u32,
collector: CClauseCollector,
collector_data: *mut c_void,
) {
assert!(min_bound <= max_bound);
let mut collector = ClauseCollector::new(collector, collector_data);
let mut var_manager = VarManager::new(n_vars_used);
unsafe { (*gte).encode_ub_change(min_bound..=max_bound, &mut collector, &mut var_manager) }
.expect("clause collector returned out of memory");
}

/// Returns assumptions/units for enforcing an upper bound (`sum of lits <= ub`). Make sure that
/// [`gte_encode_ub`] has been called adequately and nothing has been called afterwards, otherwise
/// [`MaybeError::NotEncoded`] will be returned.
///
/// Assumptions are returned via the collector callback. There is _no_ terminating zero, all
/// assumptions are passed when [`gte_enforce_ub`] returns.
///
/// # Safety
///
/// `gte` must be a return value of [`gte_new`] that [`gte_drop`] has not yet been called on.
#[no_mangle]
pub unsafe extern "C" fn gte_enforce_ub(
gte: *mut DbGte,
ub: usize,
collector: CAssumpCollector,
collector_data: *mut c_void,
) -> MaybeError {
match unsafe { (*gte).enforce_ub(ub) } {
Ok(assumps) => {
for l in assumps {
collector(l.to_ipasir(), collector_data);
}
MaybeError::Ok
}
Err(err) => err.into(),
}
}

/// Frees the memory associated with a [`DbGte`]
///
/// # Safety
///
/// `gte` must be a return value of [`gte_new`] and cannot be used
/// afterwards again.
#[no_mangle]
pub unsafe extern "C" fn gte_drop(gte: *mut DbGte) {
drop(unsafe { Box::from_raw(gte) });
}

// TODO: figure out how to get these to work on windows
#[cfg(all(test, not(target_os = "windows")))]
mod tests {
use inline_c::assert_c;

#[test]
fn new_drop() {
(assert_c! {
#include <assert.h>
#include "rustsat.h"

int main() {
DbGte *gte = gte_new();
assert(gte != NULL);
gte_drop(gte);
return 0;
}
})
.success();
}

#[test]
fn basic() {
(assert_c! {
#include <assert.h>
#include <stdio.h>
#include "rustsat.h"

void clause_counter(int lit, void *data) {
if (!lit) {
int *cnt = (int *)data;
(*cnt)++;
}
}

int main() {
DbGte *gte = gte_new();
assert(gte_add(gte, 1, 1) == Ok);
assert(gte_add(gte, 2, 2) == Ok);
assert(gte_add(gte, 3, 3) == Ok);
assert(gte_add(gte, 4, 4) == Ok);
uint32_t n_used = 4;
uint32_t n_clauses = 0;
gte_encode_ub(gte, 0, 6, &n_used, &clause_counter, &n_clauses);
gte_drop(gte);
printf("%d", n_used);
assert(n_used == 24);
assert(n_clauses == 25);
return 0;
}
})
.success();
}
}
16 changes: 16 additions & 0 deletions src/encodings/pb/dbgte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,22 @@ mod tests {
assert_eq!(cnf2.len(), gte2.n_clauses());
}

#[test]
fn from_capi() {
let mut gte1 = DbGte::default();
let mut lits = RsHashMap::default();
lits.insert(lit![0], 1);
lits.insert(lit![1], 2);
lits.insert(lit![2], 3);
lits.insert(lit![3], 4);
gte1.extend(lits);
let mut var_manager = BasicVarManager::from_next_free(var![4]);
let mut cnf = Cnf::new();
gte1.encode_ub(0..=6, &mut cnf, &mut var_manager).unwrap();
debug_assert_eq!(var_manager.n_used(), 24);
debug_assert_eq!(cnf.len(), 25);
}

#[test]
fn ub_gte_multiplication() {
let mut gte1 = DbGte::default();
Expand Down

0 comments on commit af19b19

Please sign in to comment.