Skip to content

Commit

Permalink
Improve task-cache implementation: reuse the hash computed in `can_skip` when saving the cache, and version the task cache directory
Browse files Browse the repository at this point in the history
  • Loading branch information
wolfv committed Mar 13, 2024
1 parent 869ae64 commit ce690f8
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 49 deletions.
17 changes: 10 additions & 7 deletions src/cli/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::activation::get_environment_variables;
use crate::environment::verify_prefix_location_unchanged;
use crate::project::errors::UnsupportedPlatformError;
use crate::task::{
AmbiguousTask, ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory,
AmbiguousTask, CanSkip, ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory,
SearchEnvironments, TaskAndEnvironment, TaskGraph, TaskName,
};
use crate::Project;
Expand Down Expand Up @@ -145,15 +145,18 @@ pub async fn execute(args: Args) -> miette::Result<()> {
}

// check task cache
if executable_task
let task_cache = match executable_task
.can_skip(&lock_file)
.await
.into_diagnostic()?
{
eprintln!("Task can be skipped (cache hit) 🚀");
task_idx += 1;
continue;
}
CanSkip::No(cache) => cache,
CanSkip::Yes => {
eprintln!("Task can be skipped (cache hit) 🚀");
task_idx += 1;
continue;
}
};

// If we don't have a command environment yet, we need to compute it. We lazily compute the
// task environment because we only need the environment if a task is actually executed.
Expand Down Expand Up @@ -183,7 +186,7 @@ pub async fn execute(args: Args) -> miette::Result<()> {

// Update the task cache with the new hash
executable_task
.save_cache(&lock_file)
.save_cache(&lock_file, task_cache)
.await
.into_diagnostic()?;
}
Expand Down
2 changes: 1 addition & 1 deletion src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub const PREFIX_FILE_NAME: &str = "prefix";
pub const ENVIRONMENTS_DIR: &str = "envs";
pub const SOLVE_GROUP_ENVIRONMENTS_DIR: &str = "solve-group-envs";
pub const PYPI_DEPENDENCIES: &str = "pypi-dependencies";
pub const TASK_CACHE_DIR: &str = "task-cache";
pub const TASK_CACHE_DIR: &str = "task-cache-v0";

pub const DEFAULT_ENVIRONMENT_NAME: &str = "default";

Expand Down
97 changes: 57 additions & 40 deletions src/task/executable_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ use std::{
use thiserror::Error;
use tokio::task::JoinHandle;

use super::task_hash::TaskCache;
use super::TaskHash;
use super::task_hash::{InputHashesError, TaskCache, TaskHash};

/// Runs task in project.
#[derive(Default, Debug)]
Expand Down Expand Up @@ -53,6 +52,19 @@ pub enum TaskExecutionError {
FailedToParseShellScript(#[from] FailedToParseShellScript),
}

/// Errors that can occur while updating (writing) the task cache file.
#[derive(Debug, Error, Diagnostic)]
pub enum CacheUpdateError {
    /// Creating the cache directory or writing the cache file failed.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// Computing the input/output hashes of the task failed.
    #[error(transparent)]
    TaskHashError(#[from] InputHashesError),
}

/// Outcome of checking the task cache: whether running the task can be skipped.
pub enum CanSkip {
    /// The cached hash matches the current task hash; the task need not run.
    Yes,
    /// The task must run. Carries the already-computed hash (when one could be
    /// computed) so `save_cache` can update the cache without re-hashing the inputs.
    No(Option<TaskHash>),
}

/// A task that contains enough information to be able to execute it. The lifetime [`'p`] refers to
/// the lifetime of the project that contains the tasks.
#[derive(Clone)]
Expand Down Expand Up @@ -186,66 +198,71 @@ impl<'p> ExecutableTask<'p> {
})
}

/// Returns the file name under which this task's cache entry is stored.
///
/// The cache file holds the hashes of the task's inputs and outputs; its name
/// has the form `<run-environment>-<task-name>.json` (falling back to
/// `"default"` when the task has no name).
pub(crate) fn cache_name(&self) -> String {
    let task_name = self.name().unwrap_or("default");
    format!("{}-{}.json", self.run_environment.name(), task_name)
}

/// Checks if the task can be skipped. If the task can be skipped, it returns `CanSkip::Yes`.
/// If the task cannot be skipped, it returns `CanSkip::No` and includes the hash of the task
/// that caused the task to not be skipped - we can use this later to update the cache file quickly.
pub(crate) async fn can_skip(
&self,
lock_file: &LockFileDerivedData<'_>,
) -> Result<bool, std::io::Error> {
) -> Result<CanSkip, std::io::Error> {
tracing::info!("Checking if task can be skipped");
let task_cache_folder = self.project().task_cache_folder();

let project_name = self.project().name();
let environment_name = self.run_environment.name();

let cache_name = format!(
"{}-{}-{}.json",
project_name,
environment_name,
self.name().unwrap_or("default")
);

let cache_file = task_cache_folder.join(cache_name);
let cache_name = self.cache_name();
let cache_file = self.project().task_cache_folder().join(cache_name);
if cache_file.exists() {
let cache = std::fs::read_to_string(&cache_file).unwrap();
let cache: TaskCache = serde_json::from_str(&cache).unwrap();
let hash = TaskHash::from_task(self, &lock_file.lock_file).await;
if let Ok(Some(hash)) = hash {
return Ok(hash.computation_hash() == cache.hash);
if hash.computation_hash() != cache.hash {
return Ok(CanSkip::No(Some(hash)));
} else {
return Ok(CanSkip::Yes);
}
}
}
Ok(false)
Ok(CanSkip::No(None))
}

/// Saves the cache of the task. This function will update the cache file with the new hash of
/// the task (inputs and outputs). If the task has no hash, it will not save the cache.
pub(crate) async fn save_cache(
&self,
lock_file: &LockFileDerivedData<'_>,
) -> Result<(), std::io::Error> {
previous_hash: Option<TaskHash>,
) -> Result<(), CacheUpdateError> {
let task_cache_folder = self.project().task_cache_folder();
if !task_cache_folder.exists() {
std::fs::create_dir_all(&task_cache_folder)?;
}
let project_name = self.project().name();
let environment_name = self.run_environment.name();

let cache_name = format!(
"{}-{}-{}.json",
project_name,
environment_name,
self.name().unwrap_or("default")
);

let cache_file = task_cache_folder.join(cache_name);
if let Some(hash) = TaskHash::from_task(self, &lock_file.lock_file)
let cache_file = task_cache_folder.join(self.cache_name());
let new_hash = if let Some(mut previous_hash) = previous_hash {
previous_hash.update_output(self).await?;
previous_hash
} else if let Some(hash) = TaskHash::from_task(self, &lock_file.lock_file)
.await
.unwrap()
{
let cache = TaskCache {
hash: hash.computation_hash(),
};
let cache = serde_json::to_string(&cache).unwrap();
std::fs::write(&cache_file, cache)
hash
} else {
Ok(())
return Ok(());
};

if !task_cache_folder.exists() {
std::fs::create_dir_all(&task_cache_folder)?;
}

let cache = TaskCache {
hash: new_hash.computation_hash(),
};
let cache = serde_json::to_string(&cache).unwrap();
Ok(std::fs::write(&cache_file, cache)?)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub use file_hashes::{FileHashes, FileHashesError};
pub use task_hash::{ComputationHash, InputHashes, TaskHash};

pub use executable_task::{
ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory, RunOutput,
CanSkip, ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory, RunOutput,
TaskExecutionError,
};
pub use task_environment::{
Expand Down
8 changes: 8 additions & 0 deletions src/task/task_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ impl TaskHash {
}))
}

/// Recomputes the output hashes for `task` and stores them on this hash.
/// The input hashes are left untouched, so a hash computed before the task
/// ran can be refreshed cheaply afterwards.
pub async fn update_output(
    &mut self,
    task: &ExecutableTask<'_>,
) -> Result<(), InputHashesError> {
    let outputs = OutputHashes::from_task(task).await?;
    self.outputs = outputs;
    Ok(())
}

/// Computes a single hash for the task.
pub fn computation_hash(&self) -> ComputationHash {
let mut hasher = Xxh3::new();
Expand Down

0 comments on commit ce690f8

Please sign in to comment.