Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(kernels::grisubal): rewrite step 1 to enable parallelization #210

Merged
merged 8 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion honeycomb-kernels/src/grisubal/grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
///
/// Cells `(X, Y)` take value in range `(0, 0)` to `(N, M)`,
/// from left to right (X), from bottom to top (Y).
#[derive(PartialEq)]
#[derive(PartialEq, Clone, Copy)]
pub struct GridCellId(pub usize, pub usize);

impl GridCellId {
Expand Down
155 changes: 90 additions & 65 deletions honeycomb-kernels/src/grisubal/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,42 +193,65 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
[cx, cy]: [T; 2],
origin: Vertex2<T>,
) -> (Segments, Vec<(DartIdentifier, T)>) {
let mut intersection_metadata = Vec::new();
let mut new_segments = HashMap::with_capacity(geometry.poi.len() * 2); // that *2 has no basis
geometry.segments.iter().for_each(|&(v1_id, v2_id)| {
// fetch vertices of the segment
let Vertex2(ox, oy) = origin;
let (v1, v2) = (&geometry.vertices[v1_id], &geometry.vertices[v2_id]);
// compute their position in the grid
// we assume that the origin of the grid is at (0., 0.)
let (c1, c2) = (
GridCellId(
((v1.x() - ox) / cx).floor().to_usize().unwrap(),
((v1.y() - oy) / cy).floor().to_usize().unwrap(),
),
GridCellId(
((v2.x() - ox) / cx).floor().to_usize().unwrap(),
((v2.y() - oy) / cy).floor().to_usize().unwrap(),
),
);
let tmp: Vec<_> = geometry
.segments
.iter()
.map(|&(v1_id, v2_id)| {
// fetch vertices of the segment
let Vertex2(ox, oy) = origin;
let (v1, v2) = (&geometry.vertices[v1_id], &geometry.vertices[v2_id]);
// compute their position in the grid
// we assume that the origin of the grid is at (0., 0.)
let (c1, c2) = (
GridCellId(
((v1.x() - ox) / cx).floor().to_usize().unwrap(),
((v1.y() - oy) / cy).floor().to_usize().unwrap(),
),
GridCellId(
((v2.x() - ox) / cx).floor().to_usize().unwrap(),
((v2.y() - oy) / cy).floor().to_usize().unwrap(),
),
);
(
GridCellId::man_dist(&c1, &c2),
GridCellId::diff(&c1, &c2),
v1,
v2,
v1_id,
v2_id,
c1,
)
})
.collect();
let n_intersec: usize = tmp.iter().map(|(dist, _, _, _, _, _, _)| dist).sum();
let prefix_sum: Vec<usize> = (0..tmp.len())
.map(|i| (0..i).map(|idx| tmp[idx].0).sum())
.collect();
imrn99 marked this conversation as resolved.
Show resolved Hide resolved
let intersec_ids: Vec<_> = tmp
.iter()
.map(|(dist, _, _, _, _, _, _)| dist)
.zip(prefix_sum.iter())
.map(|(&n_i, &start)| start..start + n_i)
.collect();
imrn99 marked this conversation as resolved.
Show resolved Hide resolved
let mut intersection_metadata = vec![(NULL_DART_ID, T::nan()); n_intersec];
let new_segments: Segments = tmp.iter().zip(intersec_ids).flat_map(|(&(dist, diff, v1, v2, v1_id, v2_id, c1), i_ids)| {
let transform = Box::new(|seg: &[GeometryVertex]| {
assert_eq!(seg.len(), 2);
(seg[0].clone(), seg[1].clone())
});
// check neighbor status
match GridCellId::man_dist(&c1, &c2) {
match dist {
// trivial case:
// v1 & v2 belong to the same cell
0 => {
new_segments.insert(
make_geometry_vertex!(geometry, v1_id),
make_geometry_vertex!(geometry, v2_id),
);
vec![(make_geometry_vertex!(geometry, v1_id), make_geometry_vertex!(geometry, v2_id))]
}
// ok case:
// v1 & v2 belong to neighboring cells
1 => {
// fetch base dart of the cell of v1
#[allow(clippy::cast_possible_truncation)]
let d_base = (1 + 4 * c1.0 + nx * 4 * c1.1) as DartIdentifier;
// which edge of the cell are we intersecting?
let diff = GridCellId::diff(&c1, &c2);
// which dart does this correspond to?
#[rustfmt::skip]
let dart_id = match diff {
Expand All @@ -253,25 +276,18 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
_ => unreachable!(),
};

// FIXME: these two lines should be atomic
let id = intersection_metadata.len();
intersection_metadata.push((dart_id, t));

new_segments.insert(
make_geometry_vertex!(geometry, v1_id),
GeometryVertex::Intersec(id),
);
new_segments.insert(
GeometryVertex::Intersec(id),
make_geometry_vertex!(geometry, v2_id),
);
debug_assert_eq!(i_ids.len(), 1);
let id = i_ids.start;
intersection_metadata[id] = (dart_id, t);

vec![
(make_geometry_vertex!(geometry, v1_id), GeometryVertex::Intersec(id)),
(GeometryVertex::Intersec(id), make_geometry_vertex!(geometry, v2_id)),
]
}
// highly annoying case:
// v1 & v2 do not belong to neighboring cell
_ => {
// because we're using strait segments (not curves), the manhattan distance gives us
// the number of cell we're going through to reach v2 from v1
let diff = GridCellId::diff(&c1, &c2);
// pure vertical / horizontal traversal are treated separately because it ensures we're not trying
// to compute intersections of parallel segments (which results at best in a division by 0)
match diff {
Expand All @@ -284,7 +300,7 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
// i > 0: i_base..i_base + i
// or
// i < 0: i_base + 1 + i..i_base + 1
(min(i_base, i_base + 1 + i)..max(i_base + i, i_base + 1)).map(|x| {
(min(i_base, i_base + 1 + i)..max(i_base + i, i_base + 1)).zip(i_ids).map(|(x, id)| {
// cell base dart
let d_base =
(1 + 4 * x + (nx * 4 * c1.1) as isize) as DartIdentifier;
Expand All @@ -304,24 +320,27 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
left_intersec!(v1, v2, v_dart, cy)
};

// FIXME: these two lines should be atomic
let id = intersection_metadata.len();
intersection_metadata.push((dart_id, t));
intersection_metadata[id] = (dart_id, t);

GeometryVertex::Intersec(id)
});

// because of how the range is written, we need to reverse the iterator in one case
// to keep intersection ordered from v1 to v2 (i.e. ensure the segments we build are correct)
let mut vs: VecDeque<GeometryVertex> = if i > 0 {
tmp.collect()
} else {
tmp.rev().collect()
};

// complete the vertex list
vs.push_front(make_geometry_vertex!(geometry, v1_id));
vs.push_back(make_geometry_vertex!(geometry, v2_id));
vs.make_contiguous().windows(2).for_each(|seg| {
new_segments.insert(seg[0].clone(), seg[1].clone());
});

vs.make_contiguous()
.windows(2)
.map(transform)
.collect::<Vec<_>>()
}
(0, j) => {
// we can solve the intersection equation
Expand All @@ -332,7 +351,7 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
// j > 0: j_base..j_base + j
// or
// j < 0: j_base + 1 + j..j_base + 1
(min(j_base, j_base + 1 + j)..max(j_base + j, j_base + 1)).map(|y| {
(min(j_base, j_base + 1 + j)..max(j_base + j, j_base + 1)).zip(i_ids).map(|(y, id)| {
// cell base dart
let d_base = (1 + 4 * c1.0 + nx * 4 * y as usize) as DartIdentifier;
// intersected dart
Expand All @@ -347,26 +366,27 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
down_intersec!(v1, v2, v_dart, cx)
};

// FIXME: these two lines should be atomic
let id = intersection_metadata.len();
intersection_metadata.push((dart_id, t));
intersection_metadata[id] = (dart_id, t);

GeometryVertex::Intersec(id)
});

// because of how the range is written, we need to reverse the iterator in one case
// to keep intersection ordered from v1 to v2 (i.e. ensure the segments we build are correct)
let mut vs: VecDeque<GeometryVertex> = if j > 0 {
tmp.collect()
} else {
tmp.rev().collect()
};

// complete the vertex list
vs.push_front(make_geometry_vertex!(geometry, v1_id));
vs.push_back(make_geometry_vertex!(geometry, v2_id));
// insert new segments
vs.make_contiguous().windows(2).for_each(|seg| {
new_segments.insert(seg[0].clone(), seg[1].clone());
});

vs.make_contiguous()
.windows(2)
.map(transform)
.collect::<Vec<_>>()
}
(i, j) => {
// in order to process this, we'll consider a "sub-grid" & use the direction of the segment to
Expand Down Expand Up @@ -454,6 +474,7 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
None
})
.collect();

// sort intersections from v1 to v2
intersec_data.retain(|(s, _, _)| (T::zero() <= *s) && (*s <= T::one()));
// panic unreachable because of the retain above; there's no s s.t. s == NaN
Expand All @@ -462,31 +483,34 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(

// collect geometry vertices
let mut vs = vec![make_geometry_vertex!(geometry, v1_id)];
vs.extend(intersec_data.iter_mut().map(|(_, t, dart_id)| {
vs.extend(intersec_data.iter_mut().zip(i_ids).map(|((_, t, dart_id), id)| {
if t.is_zero() {
// we assume that the segment fully goes through the corner and does not land exactly
// on it, this allows us to compute directly the dart from which the next segment
// should start: the one incident to the vertex in the opposite quadrant

// in that case, the preallocated intersection metadata slot will stay as (0, Nan)
// this is ok, we can simply ignore the entry when processing the data later

let dart_in = *dart_id;
GeometryVertex::IntersecCorner(dart_in)
} else {
// FIXME: these two lines should be atomic
let id = intersection_metadata.len();
intersection_metadata.push((*dart_id, *t));
intersection_metadata[id] = (*dart_id, *t);

GeometryVertex::Intersec(id)
}
}));

vs.push(make_geometry_vertex!(geometry, v2_id));
// insert segments
vs.windows(2).for_each(|seg| {
new_segments.insert(seg[0].clone(), seg[1].clone());
});

vs.windows(2)
.map(transform)
.collect::<Vec<_>>()
}
}
}
};
});
}
}).collect();
(new_segments, intersection_metadata)
}

Expand All @@ -499,6 +523,7 @@ pub(super) fn group_intersections_per_edge<T: CoordsFloat>(
HashMap::new();
intersection_metadata
.into_iter()
.filter(|(_, t)| !t.is_nan())
.enumerate()
.for_each(|(idx, (dart_id, mut t))| {
// classify intersections per edge_id & adjust t if needed
Expand Down
13 changes: 11 additions & 2 deletions honeycomb-kernels/src/grisubal/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ fn regular_intersections() {
generate_intersection_data(&cmap, &geometry, [2, 2], [1.0, 1.0], Vertex2::default());

assert_eq!(intersection_metadata.len(), 4);
// FIXME: INDEX ACCESSES WON'T WORK IN PARALLEL
assert_eq!(intersection_metadata[0], (2, 0.5));
assert_eq!(intersection_metadata[1], (7, 0.5));
assert_eq!(intersection_metadata[2], (16, 0.5));
Expand Down Expand Up @@ -261,6 +260,8 @@ fn regular_intersections() {

#[test]
fn corner_intersection() {
use num_traits::Float;

let mut cmap = CMapBuilder::from(
GridDescriptor::default()
.len_per_cell([1.0; 3])
Expand All @@ -280,7 +281,15 @@ fn corner_intersection() {
let (segments, intersection_metadata) =
generate_intersection_data(&cmap, &geometry, [2, 2], [1.0, 1.0], Vertex2::default());

assert_eq!(intersection_metadata.len(), 2);
// because we intersec a corner, some entries were preallocated but not needed.
// entries were initialized with (0, Nan), so they're easy to filter
assert_eq!(
intersection_metadata
.iter()
.filter(|(_, t)| !t.is_nan())
.count(),
2
);
assert_eq!(intersection_metadata[0], (2, 0.5));
assert_eq!(intersection_metadata[1], (7, 0.5));

Expand Down