Skip to content

Commit

Permalink
Merge pull request #12 from scottlamb/faster
Browse files Browse the repository at this point in the history
faster Rust programs
  • Loading branch information
jinyus authored Sep 24, 2023
2 parents 1a7e6b2 + 0cbab20 commit d47badc
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 79 deletions.
1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ serde_json = "1.0.107"
[profile.release]
lto = true
codegen-units = 1
debug = "line-tables-only"
132 changes: 95 additions & 37 deletions rust/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{cmp::Reverse, collections::BinaryHeap, time::Instant};
use std::{
collections::BinaryHeap,
time::{Duration, Instant},
};

use rustc_data_structures::fx::FxHashMap;
use serde::{Deserialize, Serialize};
Expand All @@ -16,19 +19,51 @@ struct Post {
tags: Vec<String>,
}

const NUM_TOP_ITEMS: usize = 5;

#[derive(Serialize)]
struct RelatedPosts<'a> {
_id: &'a String,
tags: &'a Vec<String>,
related: Vec<&'a Post>,
}

fn main() {
let json_str = std::fs::read_to_string("../posts.json").unwrap();
let posts: Vec<Post> = from_str(&json_str).unwrap();
#[derive(Eq)]
struct PostCount {
post: usize,
count: usize,
}

let start = Instant::now();
impl std::cmp::PartialEq for PostCount {
fn eq(&self, other: &Self) -> bool {
self.count == other.count
}
}

impl std::cmp::PartialOrd for PostCount {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl std::cmp::Ord for PostCount {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
other.count.cmp(&self.count)
}
}

fn least_n<T: Ord>(n: usize, mut from: impl Iterator<Item = T>) -> impl Iterator<Item = T> {
let mut h = BinaryHeap::from_iter(from.by_ref().take(n));
for it in from {
let mut greatest = h.peek_mut().unwrap();
if it < *greatest {
*greatest = it;
}
}
h.into_iter()
}

fn process(posts: &[Post]) -> Vec<RelatedPosts<'_>> {
let mut post_tags_map: FxHashMap<&String, Vec<usize>> = FxHashMap::default();

for (i, post) in posts.iter().enumerate() {
Expand All @@ -37,45 +72,68 @@ fn main() {
}
}

let mut related_posts: Vec<RelatedPosts> = Vec::with_capacity(posts.len());

for (idx, post) in posts.iter().enumerate() {
// faster than allocating outside the loop
let mut tagged_post_count = vec![0; posts.len()];

for tag in &post.tags {
if let Some(tag_posts) = post_tags_map.get(tag) {
for &other_post_idx in tag_posts {
if idx != other_post_idx {
tagged_post_count[other_post_idx] += 1;
posts
.iter()
.enumerate()
.map(|(idx, post)| {
// faster than allocating outside the loop
let mut tagged_post_count = vec![0; posts.len()];

for tag in &post.tags {
if let Some(tag_posts) = post_tags_map.get(tag) {
for &other_post_idx in tag_posts {
if idx != other_post_idx {
tagged_post_count[other_post_idx] += 1;
}
}
}
}
}

let mut top_five = BinaryHeap::new();
tagged_post_count
.into_iter()
.enumerate()
.for_each(|(post, count)| {
if top_five.len() < 5 {
top_five.push((Reverse(count), post));
} else {
let (Reverse(cnt), _) = top_five.peek().unwrap();
if count > *cnt {
top_five.pop();
top_five.push((Reverse(count), post));
}
}
});
let top = least_n(
NUM_TOP_ITEMS,
tagged_post_count
.iter()
.enumerate()
.map(|(post, &count)| PostCount { post, count }),
);
let related = top.map(|it| &posts[it.post]).collect();

RelatedPosts {
_id: &post._id,
tags: &post.tags,
related,
}
})
.collect()
}

related_posts.push(RelatedPosts {
_id: &post._id,
tags: &post.tags,
related: top_five.into_iter().map(|(_, post)| &posts[post]).collect(),
});
fn main() {
let json_str = std::fs::read_to_string("../posts.json").unwrap();
let posts: Vec<Post> = from_str(&json_str).unwrap();

let start = Instant::now();
let args: Vec<_> = std::env::args().collect();
match &args[..] {
[_progname] => {}
[_progname, secs] => {
let target = start + Duration::from_secs(secs.parse().unwrap());
let mut i = 0;
let mut now;
loop {
process(&posts);
i += 1;
now = Instant::now();
if now > target {
break;
}
}
println!("{:?} per iteration", (now - start) / i);
return;
}
_ => panic!("invalid arguments"),
}

let related_posts = process(&posts);
let end = Instant::now();

print!(
Expand Down
1 change: 1 addition & 0 deletions rust_rayon/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ serde_json = "1.0.107"
[profile.release]
lto = true
codegen-units = 1
debug = "line-tables-only"
136 changes: 94 additions & 42 deletions rust_rayon/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use std::{
cmp::Reverse,
collections::BinaryHeap,
sync::{Arc, Mutex},
time::Instant,
time::{Duration, Instant},
};

use rayon::prelude::*;
Expand All @@ -22,19 +20,51 @@ struct Post {
tags: Vec<String>,
}

const NUM_TOP_ITEMS: usize = 5;

#[derive(Debug, Serialize)]
struct RelatedPosts<'a> {
_id: &'a String,
tags: &'a Vec<String>,
related: Vec<&'a Post>,
}

fn main() {
let json_str = std::fs::read_to_string("../posts.json").unwrap();
let posts: Vec<Post> = from_str(&json_str).unwrap();
#[derive(Eq)]
struct PostCount {
post: usize,
count: usize,
}

let start = Instant::now();
impl std::cmp::PartialEq for PostCount {
fn eq(&self, other: &Self) -> bool {
self.count == other.count
}
}

impl std::cmp::PartialOrd for PostCount {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl std::cmp::Ord for PostCount {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
other.count.cmp(&self.count)
}
}

fn least_n<T: Ord>(n: usize, mut from: impl Iterator<Item = T>) -> impl Iterator<Item = T> {
let mut h = BinaryHeap::from_iter(from.by_ref().take(n));
for it in from {
let mut greatest = h.peek_mut().unwrap();
if it < *greatest {
*greatest = it;
}
}
h.into_iter()
}

fn process(posts: &[Post]) -> Vec<RelatedPosts<'_>> {
let mut post_tags_map: FxHashMap<&String, Vec<usize>> = FxHashMap::default();

for (i, post) in posts.iter().enumerate() {
Expand All @@ -43,54 +73,76 @@ fn main() {
}
}

let related_posts: Arc<Mutex<Vec<RelatedPosts>>> =
Arc::new(Mutex::new(Vec::with_capacity(posts.len())));

posts.par_iter().enumerate().for_each(|(idx, post)| {
let mut tagged_post_count = vec![0; posts.len()];

for tag in &post.tags {
if let Some(tag_posts) = post_tags_map.get(tag) {
for &other_post_idx in tag_posts {
if idx != other_post_idx {
tagged_post_count[other_post_idx] += 1;
posts
.par_iter()
.enumerate()
.map(|(idx, post)| {
let mut tagged_post_count = vec![0; posts.len()];
tagged_post_count.fill(0);

for tag in &post.tags {
if let Some(tag_posts) = post_tags_map.get(tag) {
for &other_post_idx in tag_posts {
if idx != other_post_idx {
tagged_post_count[other_post_idx] += 1;
}
}
}
}
}

let mut top_five = BinaryHeap::new();
tagged_post_count
.into_iter()
.enumerate()
.for_each(|(post, count)| {
if top_five.len() < 5 {
top_five.push((Reverse(count), post));
} else {
let (Reverse(cnt), _) = top_five.peek().unwrap();
if count > *cnt {
top_five.pop();
top_five.push((Reverse(count), post));
}
}
});
let top = least_n(
NUM_TOP_ITEMS,
tagged_post_count
.iter()
.enumerate()
.map(|(post, &count)| PostCount { post, count }),
);

let related = top.map(|it| &posts[it.post]).collect();

RelatedPosts {
_id: &post._id,
tags: &post.tags,
related,
}
})
.collect()
}

let related = top_five.into_iter().map(|(_, post)| &posts[post]).collect();
fn main() {
let json_str = std::fs::read_to_string("../posts.json").unwrap();
let posts: Vec<Post> = from_str(&json_str).unwrap();

related_posts.lock().unwrap().push(RelatedPosts {
_id: &post._id,
tags: &post.tags,
related,
});
});
let start = Instant::now();
let args: Vec<_> = std::env::args().collect();
match &args[..] {
[_progname] => {}
[_progname, secs] => {
let target = start + Duration::from_secs(secs.parse().unwrap());
let mut i = 0;
let mut now;
loop {
process(&posts);
i += 1;
now = Instant::now();
if now > target {
break;
}
}
println!("{:?} per iteration", (now - start) / i);
return;
}
_ => panic!("invalid arguments"),
}

let related_posts = process(&posts);
let end = Instant::now();

print!(
"Processing time (w/o IO): {:?}\n",
end.duration_since(start)
);

let json_str = serde_json::to_string(related_posts.lock().unwrap().as_slice()).unwrap();
let json_str = serde_json::to_string(&related_posts).unwrap();
std::fs::write("../related_posts_rust_rayon.json", json_str).unwrap();
}

0 comments on commit d47badc

Please sign in to comment.