Skip to content

Commit

Permalink
Refactor flaw and resolver creation to use std::make_unique; update z…
Browse files Browse the repository at this point in the history
…3 classes to use atom_expr type
  • Loading branch information
riccardodebenedictis committed Jan 24, 2025
1 parent d63f0f5 commit 36cfb1a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 33 deletions.
24 changes: 13 additions & 11 deletions include/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ namespace ratio
Tp &new_flaw(Args &&...args) noexcept
{
static_assert(std::is_base_of_v<flaw, Tp>, "Tp must be a subclass of flaw");
auto f = new Tp(std::forward<Args>(args)...);
NEW_FLAW(*f);
flaws.emplace_back(std::unique_ptr<flaw>(f));
flaw_q.push_back(*f); // add to the flaw queue..
if (f->get_causes().empty())
root_flaws.push_back(*f); // add to the root-level flaws..
return *f;
auto f = std::make_unique<Tp>(std::forward<Args>(args)...);
auto &f_ref = *f;
NEW_FLAW(f_ref);
flaws.emplace_back(std::move(f));
flaw_q.push_back(f_ref); // add to the flaw queue..
if (f_ref.get_causes().empty())
root_flaws.push_back(f_ref); // add to the root-level flaws..
return f_ref;
}

/**
Expand All @@ -78,10 +79,11 @@ namespace ratio
Tp &new_resolver(Args &&...args) noexcept
{
static_assert(std::is_base_of_v<resolver, Tp>, "Tp must be a subclass of resolver");
auto r = new Tp(std::forward<Args>(args)...);
NEW_RESOLVER(*r);
resolvers.emplace_back(std::unique_ptr<resolver>(r));
return *r;
auto r = std::make_unique<Tp>(std::forward<Args>(args)...);
auto &r_ref = *r;
NEW_RESOLVER(r_ref);
resolvers.emplace_back(std::move(r));
return r_ref;
}

[[nodiscard]] std::vector<std::reference_wrapper<flaw>> get_flaws() const noexcept;
Expand Down
16 changes: 8 additions & 8 deletions include/z3/z3flaws.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace ratio
{
public:
z3resolver(flaw &f, utils::rational &&intrinsic_cost) noexcept;
z3resolver(flaw &f, utils::rational &&intrinsic_cost, z3::expr &&rho) noexcept;
z3resolver(flaw &f, utils::rational &&intrinsic_cost, z3::expr rho) noexcept;

[[nodiscard]] z3::expr &get_rho() noexcept { return rho; }
[[nodiscard]] const z3::expr &get_rho() const noexcept { return rho; }
Expand All @@ -52,23 +52,23 @@ namespace ratio
public:
z3atom_flaw(z3solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, bool is_fact, riddle::predicate &pred, std::map<std::string, std::shared_ptr<riddle::item>, std::less<>> &&args) noexcept;

[[nodiscard]] riddle::atom_expr &get_atom() noexcept { return atm; }
[[nodiscard]] const riddle::atom_expr &get_atom() const noexcept { return atm; }
[[nodiscard]] atom_expr &get_atom() noexcept { return atm; }
[[nodiscard]] const atom_expr &get_atom() const noexcept { return atm; }

private:
void compute_resolvers() override;

json::json to_json() const override;

private:
riddle::atom_expr atm; // the atom that is the subject of the flaw..
atom_expr atm; // the atom that is the subject of the flaw..
};

class z3activate_fact final : public z3resolver
{
public:
z3activate_fact(z3atom_flaw &f) noexcept;
z3activate_fact(z3atom_flaw &f, z3::expr &&rho) noexcept;
z3activate_fact(z3atom_flaw &f, z3::expr rho) noexcept;

private:
void apply() override;
Expand All @@ -80,7 +80,7 @@ namespace ratio
{
public:
z3activate_goal(z3atom_flaw &f) noexcept;
z3activate_goal(z3atom_flaw &f, z3::expr &&rho) noexcept;
z3activate_goal(z3atom_flaw &f, z3::expr rho) noexcept;

private:
void apply() override;
Expand All @@ -91,15 +91,15 @@ namespace ratio
class z3unify_atom final : public z3resolver
{
public:
z3unify_atom(z3atom_flaw &f, riddle::atom_expr atm) noexcept;
z3unify_atom(z3atom_flaw &f, atom_expr atm) noexcept;

private:
void apply() override;

json::json to_json() const override;

private:
riddle::atom_expr atm; // the atom to unify with..
atom_expr atm; // the atom to unify with..
};

class z3disjunction_flaw final : public z3flaw
Expand Down
11 changes: 7 additions & 4 deletions include/z3/z3solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
namespace ratio
{
class z3flaw;
class z3atom_flaw;
class z3resolver;
class z3component_type;
class z3unify_atom;
Expand Down Expand Up @@ -79,9 +80,9 @@ namespace ratio
class atom : public riddle::atom
{
public:
atom(z3flaw &flaw, riddle::predicate &pred, bool is_fact, std::map<std::string, std::shared_ptr<riddle::item>, std::less<>> &&args);
atom(z3atom_flaw &flaw, riddle::predicate &pred, bool is_fact, std::map<std::string, std::shared_ptr<riddle::item>, std::less<>> &&args);

[[nodiscard]] z3flaw &get_flaw() noexcept { return flaw; }
[[nodiscard]] z3atom_flaw &get_flaw() noexcept { return flaw; }

[[nodiscard]] z3::expr &get_sigma() noexcept { return sigma; }
[[nodiscard]] const z3::expr &get_sigma() const noexcept { return sigma; }
Expand All @@ -93,10 +94,12 @@ namespace ratio
[[nodiscard]] json::json to_json() const override;

private:
z3flaw &flaw; // the flaw associated with this atom..
z3::expr sigma; // the activation status of the atom (i.e., 0 if inactive, 1 if active, 2 if unified)....
z3atom_flaw &flaw; // the flaw associated with this atom..
z3::expr sigma; // the activation status of the atom (i.e., 0 if inactive, 1 if active, 2 if unified)....
};

using atom_expr = std::shared_ptr<atom>;

class z3solver : public graph
{
friend class bool_item;
Expand Down
19 changes: 10 additions & 9 deletions src/z3/z3flaws.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "z3flaws.hpp"
#include "conjunction.hpp"
#include "logging.hpp"

namespace ratio
{
Expand Down Expand Up @@ -33,7 +34,7 @@ namespace ratio
}

z3resolver::z3resolver(flaw &f, utils::rational &&intrinsic_cost) noexcept : resolver(f, std::move(intrinsic_cost)), rho(static_cast<z3solver &>(f.get_graph()).ctx.bool_const(("b" + std::to_string(static_cast<z3solver &>(f.get_graph()).bool_count++)).c_str())) {}
z3resolver::z3resolver(flaw &f, utils::rational &&intrinsic_cost, z3::expr &&rho) noexcept : resolver(f, std::move(intrinsic_cost)), rho(std::move(rho)) {}
z3resolver::z3resolver(flaw &f, utils::rational &&intrinsic_cost, z3::expr rho) noexcept : resolver(f, std::move(intrinsic_cost)), rho(rho) {}

void z3resolver::add(const z3::expr &e) { static_cast<z3solver &>(get_flaw().get_graph()).slv.add(z3::implies(rho, e)); }

Expand All @@ -50,15 +51,15 @@ namespace ratio
{
for (auto unf_atm : static_cast<riddle::predicate &>(atm->get_type()).get_atoms())
if (unf_atm.get() != atm.get() && static_cast<atom *>(unf_atm.get())->get_flaw().is_expanded())
new_resolver<z3unify_atom>(*this, unf_atm);
new_resolver<z3unify_atom>(*this, std::dynamic_pointer_cast<atom>(unf_atm));

if (atm->is_fact())
if (get_resolvers().empty())
new_resolver<z3activate_fact>(*this, z3::expr(get_phi()));
new_resolver<z3activate_fact>(*this, get_phi());
else
new_resolver<z3activate_fact>(*this);
else if (get_resolvers().empty())
new_resolver<z3activate_goal>(*this, z3::expr(get_phi()));
new_resolver<z3activate_goal>(*this, get_phi());
else
new_resolver<z3activate_goal>(*this);
}
Expand All @@ -71,7 +72,7 @@ namespace ratio
}

z3activate_fact::z3activate_fact(z3atom_flaw &f) noexcept : z3resolver(f, utils::rational(1)) {}
z3activate_fact::z3activate_fact(z3atom_flaw &f, z3::expr &&rho) noexcept : z3resolver(f, utils::rational(1), std::move(rho)) {}
z3activate_fact::z3activate_fact(z3atom_flaw &f, z3::expr rho) noexcept : z3resolver(f, utils::rational(1), rho) {}
void z3activate_fact::apply()
{
// activating the resolver means activating the atom..
Expand All @@ -85,13 +86,13 @@ namespace ratio
}

z3activate_goal::z3activate_goal(z3atom_flaw &f) noexcept : z3resolver(f, utils::rational(1)) {}
z3activate_goal::z3activate_goal(z3atom_flaw &f, z3::expr &&rho) noexcept : z3resolver(f, utils::rational(1), std::move(rho)) {}
z3activate_goal::z3activate_goal(z3atom_flaw &f, z3::expr rho) noexcept : z3resolver(f, utils::rational(1), rho) {}
void z3activate_goal::apply()
{
// activating the resolver means activating the atom..
add(static_cast<atom &>(*static_cast<z3atom_flaw &>(get_flaw()).get_atom()).get_sigma() == 1);
// we also call the corresponding rule..
static_cast<riddle::predicate &>(static_cast<atom &>(*static_cast<z3atom_flaw &>(get_flaw()).get_atom()).get_type()).call(static_cast<z3atom_flaw &>(get_flaw()).get_atom());
static_cast<riddle::predicate &>(static_cast<z3atom_flaw &>(get_flaw()).get_atom()->get_type()).call(static_cast<z3atom_flaw &>(get_flaw()).get_atom());
}
json::json z3activate_goal::to_json() const
{
Expand All @@ -100,7 +101,7 @@ namespace ratio
return j_resolver;
}

z3unify_atom::z3unify_atom(z3atom_flaw &f, riddle::atom_expr atm) noexcept : z3resolver(f, utils::rational(1)), atm(atm) {}
z3unify_atom::z3unify_atom(z3atom_flaw &f, atom_expr atm) noexcept : z3resolver(f, utils::rational(1)), atm(atm) {}
void z3unify_atom::apply()
{
// unifying the atom means unifying the atoms..
Expand All @@ -112,7 +113,7 @@ namespace ratio
add(static_cast<bool_item &>(*eq).get_expr());

// we also add the corresponding causal link..
static_cast<z3solver &>(get_flaw().get_graph()).add_causal_link(static_cast<atom &>(*atm).get_flaw(), *this);
static_cast<z3solver &>(get_flaw().get_graph()).add_causal_link(atm->get_flaw(), *this);
}
json::json z3unify_atom::to_json() const
{
Expand Down
3 changes: 2 additions & 1 deletion src/z3/z3solver.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "z3solver.hpp"
#include "z3flaws.hpp"
#include "z3types.hpp"
#include "conjunction.hpp"
#include "init.hpp"
#include "logging.hpp"
#include <queue>
Expand Down Expand Up @@ -175,7 +176,7 @@ namespace ratio
return j;
}

atom::atom(z3flaw &flaw, riddle::predicate &pred, bool is_fact, std::map<std::string, std::shared_ptr<riddle::item>, std::less<>> &&args) : riddle::atom(pred, is_fact, std::move(args)), flaw(flaw), sigma(static_cast<z3solver &>(get_core()).ctx.int_const(("a" + std::to_string(static_cast<z3solver &>(get_core()).atom_count++)).c_str()))
atom::atom(z3atom_flaw &flaw, riddle::predicate &pred, bool is_fact, std::map<std::string, std::shared_ptr<riddle::item>, std::less<>> &&args) : riddle::atom(pred, is_fact, std::move(args)), flaw(flaw), sigma(static_cast<z3solver &>(get_core()).ctx.int_const(("a" + std::to_string(static_cast<z3solver &>(get_core()).atom_count++)).c_str()))
{
static_cast<z3solver &>(get_core()).slv.add(sigma >= static_cast<z3solver &>(get_core()).ctx.int_val(0));
static_cast<z3solver &>(get_core()).slv.add(sigma < static_cast<z3solver &>(get_core()).ctx.int_val(2));
Expand Down

0 comments on commit 36cfb1a

Please sign in to comment.