Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

faster Rust programs #12

Merged
merged 2 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ serde_json = "1.0.107"
[profile.release]
lto = true
codegen-units = 1
debug = "line-tables-only"
132 changes: 95 additions & 37 deletions rust/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{cmp::Reverse, collections::BinaryHeap, time::Instant};
use std::{
collections::BinaryHeap,
time::{Duration, Instant},
};

use rustc_data_structures::fx::FxHashMap;
use serde::{Deserialize, Serialize};
Expand All @@ -16,19 +19,51 @@ struct Post {
tags: Vec<String>,
}

const NUM_TOP_ITEMS: usize = 5;

/// Serializable result record for one post: it borrows that post's `_id` and
/// `tags` and holds references to the posts picked as related to it.
#[derive(Serialize)]
struct RelatedPosts<'a> {
/// Borrowed `_id` of the post this record describes.
_id: &'a String,
/// Borrowed tag list of the post this record describes.
tags: &'a Vec<String>,
/// References into the input posts chosen as related.
related: Vec<&'a Post>,
}

fn main() {
let json_str = std::fs::read_to_string("../posts.json").unwrap();
let posts: Vec<Post> = from_str(&json_str).unwrap();
/// A candidate post paired with its entry in the per-post `tagged_post_count`
/// tally (how many tag matches it accumulated against the post being processed).
#[derive(Eq)]
struct PostCount {
/// Index of the candidate in the `posts` slice.
post: usize,
/// Tally value for that candidate.
count: usize,
}

let start = Instant::now();
/// Equality looks at `count` only; the `post` index is ignored so that `==`
/// agrees with the `Ord` implementation, which also compares just `count`.
impl PartialEq for PostCount {
    fn eq(&self, rhs: &Self) -> bool {
        self.count.eq(&rhs.count)
    }
}

/// Partial ordering simply defers to the total order defined by `Ord`, so it
/// never returns `None`.
impl PartialOrd for PostCount {
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

/// Orders entries by `count` in *descending* order: a larger `count` compares
/// as `Less`. With this inversion, selecting the "least" entries yields the
/// ones with the highest counts.
impl Ord for PostCount {
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        let ascending = self.count.cmp(&rhs.count);
        ascending.reverse()
    }
}

/// Returns the `n` smallest items of `from` (according to `Ord`), in
/// unspecified order.
///
/// The first `n` items seed a max-heap of candidates. Every later item
/// replaces the current greatest candidate when it compares strictly less
/// than it, so among equal items the earlier ones are kept.
///
/// Fewer than `n` items are returned when `from` yields fewer than `n`, and
/// nothing is returned when `n` is zero.
fn least_n<T: Ord>(n: usize, mut from: impl Iterator<Item = T>) -> impl Iterator<Item = T> {
    let mut h: BinaryHeap<T> = from.by_ref().take(n).collect();
    for it in from {
        // An empty heap at this point means `n == 0`: there is no candidate
        // to displace, so stop rather than unwrap a `None` and panic.
        let Some(mut greatest) = h.peek_mut() else {
            break;
        };
        if it < *greatest {
            // Writing through `PeekMut` restores the heap order when the
            // guard is dropped at the end of this iteration.
            *greatest = it;
        }
    }
    h.into_iter()
}

fn process(posts: &[Post]) -> Vec<RelatedPosts<'_>> {
let mut post_tags_map: FxHashMap<&String, Vec<usize>> = FxHashMap::default();

for (i, post) in posts.iter().enumerate() {
Expand All @@ -37,45 +72,68 @@ fn main() {
}
}

let mut related_posts: Vec<RelatedPosts> = Vec::with_capacity(posts.len());

for (idx, post) in posts.iter().enumerate() {
// faster than allocating outside the loop
let mut tagged_post_count = vec![0; posts.len()];

for tag in &post.tags {
if let Some(tag_posts) = post_tags_map.get(tag) {
for &other_post_idx in tag_posts {
if idx != other_post_idx {
tagged_post_count[other_post_idx] += 1;
posts
.iter()
.enumerate()
.map(|(idx, post)| {
// faster than allocating outside the loop
let mut tagged_post_count = vec![0; posts.len()];

for tag in &post.tags {
if let Some(tag_posts) = post_tags_map.get(tag) {
for &other_post_idx in tag_posts {
if idx != other_post_idx {
tagged_post_count[other_post_idx] += 1;
}
}
}
}
}

let mut top_five = BinaryHeap::new();
tagged_post_count
.into_iter()
.enumerate()
.for_each(|(post, count)| {
if top_five.len() < 5 {
top_five.push((Reverse(count), post));
} else {
let (Reverse(cnt), _) = top_five.peek().unwrap();
if count > *cnt {
top_five.pop();
top_five.push((Reverse(count), post));
}
}
});
let top = least_n(
NUM_TOP_ITEMS,
tagged_post_count
.iter()
.enumerate()
.map(|(post, &count)| PostCount { post, count }),
);
let related = top.map(|it| &posts[it.post]).collect();

RelatedPosts {
_id: &post._id,
tags: &post.tags,
related,
}
})
.collect()
}

related_posts.push(RelatedPosts {
_id: &post._id,
tags: &post.tags,
related: top_five.into_iter().map(|(_, post)| &posts[post]).collect(),
});
fn main() {
let json_str = std::fs::read_to_string("../posts.json").unwrap();
let posts: Vec<Post> = from_str(&json_str).unwrap();

let start = Instant::now();
let args: Vec<_> = std::env::args().collect();
match &args[..] {
[_progname] => {}
[_progname, secs] => {
let target = start + Duration::from_secs(secs.parse().unwrap());
let mut i = 0;
let mut now;
loop {
process(&posts);
i += 1;
now = Instant::now();
if now > target {
break;
}
}
println!("{:?} per iteration", (now - start) / i);
return;
}
_ => panic!("invalid arguments"),
}

let related_posts = process(&posts);
let end = Instant::now();

print!(
Expand Down
1 change: 1 addition & 0 deletions rust_rayon/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ serde_json = "1.0.107"
[profile.release]
lto = true
codegen-units = 1
debug = "line-tables-only"
136 changes: 94 additions & 42 deletions rust_rayon/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use std::{
cmp::Reverse,
collections::BinaryHeap,
sync::{Arc, Mutex},
time::Instant,
time::{Duration, Instant},
};

use rayon::prelude::*;
Expand All @@ -22,19 +20,51 @@ struct Post {
tags: Vec<String>,
}

const NUM_TOP_ITEMS: usize = 5;

/// Serializable result record for one post: it borrows that post's `_id` and
/// `tags` and holds references to the posts picked as related to it.
#[derive(Debug, Serialize)]
struct RelatedPosts<'a> {
/// Borrowed `_id` of the post this record describes.
_id: &'a String,
/// Borrowed tag list of the post this record describes.
tags: &'a Vec<String>,
/// References into the input posts chosen as related.
related: Vec<&'a Post>,
}

fn main() {
let json_str = std::fs::read_to_string("../posts.json").unwrap();
let posts: Vec<Post> = from_str(&json_str).unwrap();
/// A candidate post paired with its entry in the per-post `tagged_post_count`
/// tally (how many tag matches it accumulated against the post being processed).
#[derive(Eq)]
struct PostCount {
/// Index of the candidate in the `posts` slice.
post: usize,
/// Tally value for that candidate.
count: usize,
}

let start = Instant::now();
/// Equality looks at `count` only; the `post` index is ignored so that `==`
/// agrees with the `Ord` implementation, which also compares just `count`.
impl PartialEq for PostCount {
    fn eq(&self, rhs: &Self) -> bool {
        self.count.eq(&rhs.count)
    }
}

/// Partial ordering simply defers to the total order defined by `Ord`, so it
/// never returns `None`.
impl PartialOrd for PostCount {
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

/// Orders entries by `count` in *descending* order: a larger `count` compares
/// as `Less`. With this inversion, selecting the "least" entries yields the
/// ones with the highest counts.
impl Ord for PostCount {
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        let ascending = self.count.cmp(&rhs.count);
        ascending.reverse()
    }
}

/// Returns the `n` smallest items of `from` (according to `Ord`), in
/// unspecified order.
///
/// The first `n` items seed a max-heap of candidates. Every later item
/// replaces the current greatest candidate when it compares strictly less
/// than it, so among equal items the earlier ones are kept.
///
/// Fewer than `n` items are returned when `from` yields fewer than `n`, and
/// nothing is returned when `n` is zero.
fn least_n<T: Ord>(n: usize, mut from: impl Iterator<Item = T>) -> impl Iterator<Item = T> {
    let mut h: BinaryHeap<T> = from.by_ref().take(n).collect();
    for it in from {
        // An empty heap at this point means `n == 0`: there is no candidate
        // to displace, so stop rather than unwrap a `None` and panic.
        let Some(mut greatest) = h.peek_mut() else {
            break;
        };
        if it < *greatest {
            // Writing through `PeekMut` restores the heap order when the
            // guard is dropped at the end of this iteration.
            *greatest = it;
        }
    }
    h.into_iter()
}

fn process(posts: &[Post]) -> Vec<RelatedPosts<'_>> {
let mut post_tags_map: FxHashMap<&String, Vec<usize>> = FxHashMap::default();

for (i, post) in posts.iter().enumerate() {
Expand All @@ -43,54 +73,76 @@ fn main() {
}
}

let related_posts: Arc<Mutex<Vec<RelatedPosts>>> =
Arc::new(Mutex::new(Vec::with_capacity(posts.len())));

posts.par_iter().enumerate().for_each(|(idx, post)| {
let mut tagged_post_count = vec![0; posts.len()];

for tag in &post.tags {
if let Some(tag_posts) = post_tags_map.get(tag) {
for &other_post_idx in tag_posts {
if idx != other_post_idx {
tagged_post_count[other_post_idx] += 1;
posts
.par_iter()
.enumerate()
.map(|(idx, post)| {
let mut tagged_post_count = vec![0; posts.len()];
tagged_post_count.fill(0);

for tag in &post.tags {
if let Some(tag_posts) = post_tags_map.get(tag) {
for &other_post_idx in tag_posts {
if idx != other_post_idx {
tagged_post_count[other_post_idx] += 1;
}
}
}
}
}

let mut top_five = BinaryHeap::new();
tagged_post_count
.into_iter()
.enumerate()
.for_each(|(post, count)| {
if top_five.len() < 5 {
top_five.push((Reverse(count), post));
} else {
let (Reverse(cnt), _) = top_five.peek().unwrap();
if count > *cnt {
top_five.pop();
top_five.push((Reverse(count), post));
}
}
});
let top = least_n(
NUM_TOP_ITEMS,
tagged_post_count
.iter()
.enumerate()
.map(|(post, &count)| PostCount { post, count }),
);

let related = top.map(|it| &posts[it.post]).collect();

RelatedPosts {
_id: &post._id,
tags: &post.tags,
related,
}
})
.collect()
}

let related = top_five.into_iter().map(|(_, post)| &posts[post]).collect();
fn main() {
let json_str = std::fs::read_to_string("../posts.json").unwrap();
let posts: Vec<Post> = from_str(&json_str).unwrap();

related_posts.lock().unwrap().push(RelatedPosts {
_id: &post._id,
tags: &post.tags,
related,
});
});
let start = Instant::now();
let args: Vec<_> = std::env::args().collect();
match &args[..] {
[_progname] => {}
[_progname, secs] => {
let target = start + Duration::from_secs(secs.parse().unwrap());
let mut i = 0;
let mut now;
loop {
process(&posts);
i += 1;
now = Instant::now();
if now > target {
break;
}
}
println!("{:?} per iteration", (now - start) / i);
return;
}
_ => panic!("invalid arguments"),
}

let related_posts = process(&posts);
let end = Instant::now();

print!(
"Processing time (w/o IO): {:?}\n",
end.duration_since(start)
);

let json_str = serde_json::to_string(related_posts.lock().unwrap().as_slice()).unwrap();
let json_str = serde_json::to_string(&related_posts).unwrap();
std::fs::write("../related_posts_rust_rayon.json", json_str).unwrap();
}