Skip to content

Commit

Permalink
voxel: Update assignment metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
Ondraceq committed Dec 11, 2023
1 parent 525e4e4 commit 04b8c50
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 98 deletions.
116 changes: 79 additions & 37 deletions softwareComponents/voxelReconfig/src/reconfig/metric/assignment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
//! This limits the metric to evaluating only a pair of modules.
use super::centered_by_pos_median;
use super::{cost::Cost, Metric};
use crate::module_repr::{get_all_module_reprs, get_other_body};
use super::potential_fn::{FSqrtSum, PotentialFn};
use crate::module_repr::{get_all_module_reprs, get_other_body, is_module_repr};
use crate::pos::Pos;
use crate::voxel::{JointPosition, Voxel};
use crate::voxel::{JointPosition, PosVoxel, Voxel};
use crate::voxel_world::normalized_eq_worlds;
use crate::voxel_world::{CenteredVoxelWorld, NormVoxelWorld, VoxelWorld};
use iter_fixed::IntoIteratorFixed;
use ndarray::Array2;
use num::{Float, ToPrimitive};
use num::ToPrimitive;
use ordered_float::OrderedFloat;
use std::marker::PhantomData;

fn voxel_joint_diff<N: num::Zero + num::One>(goal: Voxel, other: Voxel) -> N {
match (goal.joint_pos(), other.joint_pos()) {
Expand Down Expand Up @@ -58,14 +59,13 @@ fn pos_cost<TIndex: num::Signed>(lhs: Pos<TIndex>, rhs: Pos<TIndex>) -> TIndex {
.fold(num::zero(), TIndex::add)
}

fn compute_best_mapping_cost<TWorld, TGoalWorld>(other: &TWorld, goal: &TGoalWorld) -> f32
fn compute_best_mapping_cost<TWorld, TGoalWorld, TCostFn>(other: &TWorld, goal: &TGoalWorld) -> f32
where
TWorld: VoxelWorld<IndexType = TGoalWorld::IndexType>,
TGoalWorld: VoxelWorld,
TWorld: VoxelWorld,
TGoalWorld: VoxelWorld<IndexType = TWorld::IndexType>,
TWorld::IndexType: ToPrimitive,
TCostFn: CostFn<TWorld::IndexType>,
{
// let other_modules = get_world_modules(other);
// let goal_modules = get_world_modules(goal);
let other_modules_count = get_all_module_reprs(other).count();
let goal_modules_count = get_all_module_reprs(goal).count();
assert_eq!(other_modules_count, goal_modules_count);
Expand All @@ -76,51 +76,89 @@ where
for (goal_body_a, j) in get_all_module_reprs(goal).zip(0..) {
let goal_body_b = get_other_body(goal_body_a, goal).unwrap();

// Comparing shoeA with shoeB will always have equal or higher pos cost
let pos_cost =
pos_cost(other_body_a.0, goal_body_a.0) + pos_cost(other_body_b.0, goal_body_b.0);
let pos_cost = ToPrimitive::to_f32(&pos_cost).unwrap();
let voxel_cost = f32::min(
voxel_diff(
[goal_body_a.1, goal_body_b.1],
[other_body_a.1, other_body_b.1],
),
voxel_diff(
[goal_body_b.1, goal_body_a.1],
[other_body_a.1, other_body_b.1],
),
);
costs[(i, j)] = pos_cost + voxel_cost;
costs[(i, j)] =
TCostFn::compute_cost([other_body_a, other_body_b], [goal_body_a, goal_body_b]);
}
}

let best_mapping = lapjv::lapjv(&costs).unwrap();
lapjv::cost(&costs, &best_mapping.0)
}

pub struct AssignmentMetric<TWorld>
pub trait CostFn<TIndex: num::Num> {
fn compute_cost(lhs_module: [PosVoxel<TIndex>; 2], rhs_module: [PosVoxel<TIndex>; 2]) -> f32;
}

pub struct PosCostFn;
impl<TIndex> CostFn<TIndex> for PosCostFn
where
TIndex: num::Signed + ToPrimitive + Copy,
{
fn compute_cost(lhs_module: [PosVoxel<TIndex>; 2], rhs_module: [PosVoxel<TIndex>; 2]) -> f32 {
debug_assert!(is_module_repr(lhs_module[0].1));
debug_assert!(is_module_repr(rhs_module[0].1));
// Comparing shoeA with shoeB will always have equal or higher pos cost
let pos_cost =
pos_cost(lhs_module[0].0, rhs_module[0].0) + pos_cost(lhs_module[1].0, rhs_module[1].0);
ToPrimitive::to_f32(&pos_cost).expect("Invalid position cost")
}
}

pub struct VoxelCostFn;
impl<TIndex> CostFn<TIndex> for VoxelCostFn
where
TIndex: num::Signed,
{
fn compute_cost(lhs_module: [PosVoxel<TIndex>; 2], rhs_module: [PosVoxel<TIndex>; 2]) -> f32 {
f32::min(
voxel_diff(
[lhs_module[0].1, lhs_module[1].1],
[rhs_module[0].1, rhs_module[1].1],
),
voxel_diff(
[lhs_module[1].1, lhs_module[0].1],
[rhs_module[0].1, rhs_module[1].1],
),
)
}
}

pub struct PosVoxelCostFn;
impl<TIndex> CostFn<TIndex> for PosVoxelCostFn
where
TIndex: num::Signed + ToPrimitive + Copy,
{
fn compute_cost(lhs_module: [PosVoxel<TIndex>; 2], rhs_module: [PosVoxel<TIndex>; 2]) -> f32 {
PosCostFn::compute_cost(lhs_module, rhs_module)
+ VoxelCostFn::compute_cost(lhs_module, rhs_module)
}
}

pub struct AssignmentPotentialFn<TWorld, TCostFn>
where
TWorld: VoxelWorld,
TCostFn: CostFn<TWorld::IndexType>,
{
goals: Vec<CenteredVoxelWorld<TWorld, TWorld>>,
__phantom: PhantomData<TCostFn>,
}

impl<TWorld> Metric<TWorld> for AssignmentMetric<TWorld>
impl<TWorld, TCostFn> PotentialFn<TWorld> for AssignmentPotentialFn<TWorld, TCostFn>
where
TWorld: NormVoxelWorld,
TWorld::IndexType: ToPrimitive,
TCostFn: CostFn<TWorld::IndexType>,
{
type Potential = OrderedFloat<f32>;
type EstimatedCost = Self::Potential;

fn new(goal: &TWorld) -> Self
where
Self: Sized,
{
fn new(goal: &TWorld) -> Self {
let goals = normalized_eq_worlds(goal)
.map(centered_by_pos_median)
.collect();
Self { goals }
Self {
goals,
__phantom: Default::default(),
}
}

fn get_potential(&mut self, state: &TWorld) -> Self::Potential {
Expand All @@ -129,13 +167,17 @@ where
assert!(!self.goals.is_empty());
self.goals
.iter()
.map(|goal| compute_best_mapping_cost(&state, goal))
.map(|goal| compute_best_mapping_cost::<_, _, TCostFn>(&state, goal))
.map(OrderedFloat)
.min()
.unwrap()
}

fn estimated_cost(cost: Cost<Self::Potential>) -> Self::EstimatedCost {
cost.potential + OrderedFloat(cost.real_cost as f32).sqrt()
}
}

pub type PosAssgPotFn<TWorld> = AssignmentPotentialFn<TWorld, PosCostFn>;
pub type VoxelAssgPotFn<TWorld> = AssignmentPotentialFn<TWorld, VoxelCostFn>;
pub type PosVoxelAssgPotFn<TWorld> = AssignmentPotentialFn<TWorld, PosVoxelCostFn>;

pub type PosAssgMetric<TWorld> = FSqrtSum<TWorld, PosAssgPotFn<TWorld>>;
pub type VoxelAssgMetric<TWorld> = FSqrtSum<TWorld, VoxelAssgPotFn<TWorld>>;
pub type PosVoxelAssgMetric<TWorld> = FSqrtSum<TWorld, PosVoxelAssgPotFn<TWorld>>;
6 changes: 2 additions & 4 deletions softwareComponents/voxelReconfig/src/reconfig/metric/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod assignment;
pub mod cost;
pub mod naive;
pub mod potential_fn;

use self::cost::Cost;
use crate::pos::Pos;
Expand All @@ -24,10 +25,7 @@ impl<TState> Metric<TState> for ZeroMetric {
type Potential = ();
type EstimatedCost = usize;

fn new(_goal: &TState) -> Self
where
Self: Sized,
{
fn new(_goal: &TState) -> Self {
Self
}
fn get_potential(&mut self, _state: &TState) -> Self::Potential {}
Expand Down
17 changes: 8 additions & 9 deletions softwareComponents/voxelReconfig/src/reconfig/metric/naive.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
//! Metric based on finding mappings by going through all mappings
//! and evaluating the difference of worlds fixed to these mappings.
//!
//! Cannot use the best assignment algorith since it compares
//! the connection graphs of the configurations as well.
use self::graph::{Graph, Mapping, Node};
use super::{cost::Cost, Metric};
use super::potential_fn::{ISqrtSum, PotentialFn};
use crate::voxel::{JointPosition, Voxel};
use crate::voxel_world::NormVoxelWorld;
use num::integer::Roots;
use std::marker::PhantomData;

pub struct NaiveMetric<TWorld>
pub struct NaivePotential<TWorld>
where
TWorld: NormVoxelWorld,
{
goal: Graph<TWorld::IndexType>,
__phantom: PhantomData<TWorld>,
}
pub type NaiveMetric<TWorld> = ISqrtSum<TWorld, NaivePotential<TWorld>>;

impl<TWorld> NaiveMetric<TWorld>
impl<TWorld> NaivePotential<TWorld>
where
TWorld: NormVoxelWorld,
{
Expand Down Expand Up @@ -80,7 +83,7 @@ where
}
}

impl<TWorld> Metric<TWorld> for NaiveMetric<TWorld>
impl<TWorld> PotentialFn<TWorld> for NaivePotential<TWorld>
where
TWorld: NormVoxelWorld,
{
Expand All @@ -99,10 +102,6 @@ where
fn get_potential(&mut self, state: &TWorld) -> Self::Potential {
self.compute_best_potential(state)
}

fn estimated_cost(cost: Cost<Self::Potential>) -> Self::EstimatedCost {
cost.potential + cost.real_cost.sqrt()
}
}

pub mod graph {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use super::Metric;
use num::integer::Roots;
use num::traits::real::Real;
use num::FromPrimitive;
use std::marker::PhantomData;

pub trait PotentialFn<TState> {
type Potential: std::cmp::Ord + Default + Copy + std::fmt::Debug = usize;

fn new(goal: &TState) -> Self
where
Self: Sized;

fn get_potential(&mut self, state: &TState) -> Self::Potential;
}

// Uses the `potential + real_cost` for the estimated cost
pub struct Sum<TState, TPotentialFn>(pub TPotentialFn, PhantomData<TState>)
where
TPotentialFn: PotentialFn<TState>;

impl<TState, TPotentialFn> Metric<TState> for Sum<TState, TPotentialFn>
where
TPotentialFn: PotentialFn<TState>,
TPotentialFn::Potential: FromPrimitive + num::Num,
{
type Potential = TPotentialFn::Potential;
type EstimatedCost = Self::Potential;

fn new(goal: &TState) -> Self {
Self(PotentialFn::new(goal), Default::default())
}

fn get_potential(&mut self, state: &TState) -> Self::Potential {
self.0.get_potential(state)
}

fn estimated_cost(cost: super::cost::Cost<Self::Potential>) -> Self::EstimatedCost {
cost.potential
+ FromPrimitive::from_usize(cost.real_cost).expect("Cannot convert real cost")
}
}

// Uses the `potential + integer_sqrt(real_cost)` for the estimated cost
pub struct ISqrtSum<TState, TPotentialFn>(pub TPotentialFn, PhantomData<TState>)
where
TPotentialFn: PotentialFn<TState>;

impl<TState, TPotentialFn> Metric<TState> for ISqrtSum<TState, TPotentialFn>
where
TPotentialFn: PotentialFn<TState>,
TPotentialFn::Potential: num::Integer + FromPrimitive,
{
type Potential = TPotentialFn::Potential;
type EstimatedCost = TPotentialFn::Potential;

fn new(goal: &TState) -> Self {
Self(PotentialFn::new(goal), Default::default())
}

fn get_potential(&mut self, state: &TState) -> Self::Potential {
self.0.get_potential(state)
}

fn estimated_cost(cost: super::cost::Cost<Self::Potential>) -> Self::EstimatedCost {
cost.potential
+ FromPrimitive::from_usize(cost.real_cost.sqrt()).expect("Cannot convert real cost")
}
}

// Uses the `potential + sqrt(real_cost)` for the estimated cost
pub struct FSqrtSum<TState, TPotentialFn>(pub TPotentialFn, PhantomData<TState>)
where
TPotentialFn: PotentialFn<TState>;

impl<TState, TPotentialFn> Metric<TState> for FSqrtSum<TState, TPotentialFn>
where
TPotentialFn: PotentialFn<TState>,
TPotentialFn::Potential: FromPrimitive + Real,
{
type Potential = TPotentialFn::Potential;
type EstimatedCost = TPotentialFn::Potential;

fn new(goal: &TState) -> Self {
Self(PotentialFn::new(goal), Default::default())
}

fn get_potential(&mut self, state: &TState) -> Self::Potential {
self.0.get_potential(state)
}

fn estimated_cost(cost: super::cost::Cost<Self::Potential>) -> Self::EstimatedCost {
let real_cost =
Self::EstimatedCost::from_usize(cost.real_cost).expect("Cannot convert real cost");
cost.potential + real_cost.sqrt()
}
}
9 changes: 6 additions & 3 deletions tools/voxel/rofi-voxel
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,14 @@ def check_input_files(files: List[str]):
[
"bfs",
"astar-zero",
"astar-zero-opt",
"astar-naive",
"astar-naive-opt",
"astar-assignment",
"astar-assignment-opt",
"astar-assg-posvoxel",
"astar-assg-posvoxel-opt",
"astar-assg-pos",
"astar-assg-pos-opt",
"astar-assg-voxel",
"astar-assg-voxel-opt",
]
),
default="bfs",
Expand Down
Loading

0 comments on commit 04b8c50

Please sign in to comment.