Skip to content

Commit

Permalink
fix: simplify the lazy approach (#483)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgante authored Sep 1, 2024
1 parent 9f5925c commit 44e1536
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 93 deletions.
10 changes: 4 additions & 6 deletions crates/core/src/built_in_functions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
lazy::LazyTraversal, marzano_context::MarzanoContext,
marzano_resolved_pattern::MarzanoResolvedPattern, paths::resolve, problem::MarzanoQueryContext,
marzano_context::MarzanoContext, marzano_resolved_pattern::MarzanoResolvedPattern,
paths::resolve, problem::MarzanoQueryContext,
};
use anyhow::{anyhow, bail, Result};
use grit_pattern_matcher::{
Expand Down Expand Up @@ -39,8 +39,7 @@ pub type CallbackFn = dyn for<'a, 'b> Fn(
&'b <crate::problem::MarzanoQueryContext as grit_pattern_matcher::context::QueryContext>::ResolvedPattern<'a>,
&'a MarzanoContext<'a>,
&mut State<'a, MarzanoQueryContext>,
&mut AnalysisLogs,
&mut LazyTraversal<'a, 'b>
&mut AnalysisLogs
) -> Result<bool>
+ Send
+ Sync;
Expand Down Expand Up @@ -126,8 +125,7 @@ impl BuiltIns {
state: &mut State<'a, MarzanoQueryContext>,
logs: &mut AnalysisLogs,
) -> Result<bool> {
let mut lazy = LazyTraversal::new(binding);
(self.callbacks[call.callback_index])(binding, context, state, logs, &mut lazy)
(self.callbacks[call.callback_index])(binding, context, state, logs)
}

/// Add an anonymous built-in, used for callbacks
Expand Down
98 changes: 29 additions & 69 deletions crates/core/src/lazy.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,3 @@
use grit_pattern_matcher::pattern::Matcher;
use grit_pattern_matcher::pattern::Pattern;
use grit_pattern_matcher::pattern::State;
use grit_util::error::GritResult;
use grit_util::AnalysisLogs;

use crate::marzano_context::MarzanoContext;
use crate::{marzano_resolved_pattern::MarzanoResolvedPattern, problem::MarzanoQueryContext};

#[derive(Debug, Clone)]
pub struct LazyTraversal<'a, 'b> {
root: &'b MarzanoResolvedPattern<'a>,
}

impl<'a, 'b> LazyTraversal<'a, 'b> {
pub(crate) fn new(root: &'b MarzanoResolvedPattern<'a>) -> Self {
Self { root }
}

#[allow(dead_code)]
pub(crate) fn matches(
&mut self,
pattern: Pattern<MarzanoQueryContext>,
context: &'a MarzanoContext<'a>,
state: &mut State<'a, MarzanoQueryContext>,
logs: &mut AnalysisLogs,
) -> GritResult<bool> {
// THIS IS UNSAFE
// TODO: make this safe / improve the lifetimes so pattern does not need to be static
let borrowed_pattern: &'static Pattern<MarzanoQueryContext> =
unsafe { std::mem::transmute(&pattern) };

let matches = borrowed_pattern.execute(self.root, state, context, logs)?;
Ok(matches)
}
}

#[cfg(test)]
mod test {

Expand Down Expand Up @@ -75,40 +38,37 @@ mod test {
assert!(!callback_called.load(std::sync::atomic::Ordering::SeqCst));

let mut builder = PatternBuilder::start_empty(src, lang).unwrap();
builder =
builder.matches_callback(Box::new(move |_binding, context, state, logs, lazy| {
assert!(state.find_var_in_scope("$foo").is_some());
assert!(state.find_var_in_scope("$bar").is_some());
assert!(state.find_var_in_scope("$dude").is_none());
assert!(state.find_var_in_scope("$baz").is_none());
let _registered_var = state.register_var("fuzz");
assert!(state.find_var_in_scope("fuzz").is_some());

let pattern = Pattern::Contains(Box::new(Contains::new(
Pattern::<MarzanoQueryContext>::StringConstant(StringConstant::new(
"name".to_owned(),
)),
None,
)));
assert!(lazy.matches(pattern, context, state, logs).unwrap());

let non_matching_pattern = Pattern::Contains(Box::new(Contains::new(
Pattern::<MarzanoQueryContext>::StringConstant(StringConstant::new(
"not_found".to_owned(),
)),
None,
)));
assert!(!lazy
.matches(non_matching_pattern, context, state, logs)
.unwrap());

callback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(true)
}));
builder = builder.matches_callback(Box::new(move |binding, context, state, logs| {
assert!(state.find_var_in_scope("$foo").is_some());
assert!(state.find_var_in_scope("$bar").is_some());
assert!(state.find_var_in_scope("$dude").is_none());
assert!(state.find_var_in_scope("$baz").is_none());
let _registered_var = state.register_var("fuzz");
assert!(state.find_var_in_scope("fuzz").is_some());

let pattern = Pattern::Contains(Box::new(Contains::new(
Pattern::<MarzanoQueryContext>::StringConstant(StringConstant::new(
"name".to_owned(),
)),
None,
)));
assert!(binding.matches(&pattern, state, context, logs).unwrap());

let non_matching_pattern = Pattern::Contains(Box::new(Contains::new(
Pattern::<MarzanoQueryContext>::StringConstant(StringConstant::new(
"not_found".to_owned(),
)),
None,
)));
assert!(!binding
.matches(&non_matching_pattern, state, context, logs)
.unwrap());

callback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(true)
}));
let CompilationResult { problem, .. } = builder.compile(None, None, true).unwrap();

println!("problem: {:?}", problem);

let test_files = vec![SyntheticFile::new(
"file.js".to_owned(),
r#"function myLogger() {
Expand Down
29 changes: 23 additions & 6 deletions crates/core/src/marzano_resolved_pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
marzano_binding::MarzanoBinding, marzano_code_snippet::MarzanoCodeSnippet,
marzano_context::MarzanoContext, paths::absolutize, problem::MarzanoQueryContext,
};
use grit_pattern_matcher::pattern::Matcher;
use grit_pattern_matcher::{
binding::Binding,
constant::Constant,
Expand Down Expand Up @@ -46,6 +47,24 @@ impl<'a> MarzanoResolvedPattern<'a> {
Self::from_binding(MarzanoBinding::List(node, field_id))
}

/// Check if a pattern matches a provided pattern
///
/// Note this leaks memory, so should only be used in short-lived programs
#[allow(dead_code)]
pub(crate) fn matches(
&self,
pattern: &Pattern<MarzanoQueryContext>,
state: &mut State<'a, MarzanoQueryContext>,
context: &'a MarzanoContext<'a>,
logs: &mut AnalysisLogs,
) -> GritResult<bool> {
let borrowed_pattern: &'static Pattern<MarzanoQueryContext> =
Box::leak(Box::new(pattern.clone()));

let matches = borrowed_pattern.execute(self, state, context, logs)?;
Ok(matches)
}

fn to_snippets(&self) -> GritResult<Vector<ResolvedSnippet<'a, MarzanoQueryContext>>> {
match self {
MarzanoResolvedPattern::Snippets(snippets) => Ok(snippets.clone()),
Expand Down Expand Up @@ -486,12 +505,10 @@ impl<'a> ResolvedPattern<'a, MarzanoQueryContext> for MarzanoResolvedPattern<'a>
Pattern::CallBuiltIn(built_in) => built_in.call(state, context, logs),
Pattern::CallFunction(func) => func.call(state, context, logs),
Pattern::CallForeignFunction(func) => func.call(state, context, logs),
Pattern::CallbackPattern(callback) => {
Err(GritPatternError::new(format!(
"cannot make resolved pattern from callback pattern {}",
callback.name()
)))
}
Pattern::CallbackPattern(callback) => Err(GritPatternError::new(format!(
"cannot make resolved pattern from callback pattern {}",
callback.name()
))),
Pattern::StringConstant(string) => Ok(Self::Snippets(vector![ResolvedSnippet::Text(
(&string.text).into(),
)])),
Expand Down
16 changes: 14 additions & 2 deletions crates/core/src/pattern_compiler/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use grit_pattern_matcher::{
ABSOLUTE_PATH_INDEX, DEFAULT_FILE_NAME, FILENAME_INDEX, NEW_FILES_INDEX, PROGRAM_INDEX,
},
pattern::{
And, GritFunctionDefinition, Pattern, PatternDefinition, Predicate, PredicateDefinition,
VariableSourceLocations, Where,
Accumulate, And, DynamicPattern, GritFunctionDefinition, Pattern, PatternDefinition,
Predicate, PredicateDefinition, Rewrite, VariableSourceLocations, Where,
},
};
use grit_util::{AnalysisLogs, Ast, FileRange};
Expand Down Expand Up @@ -235,6 +235,18 @@ impl PatternBuilder {
Self { pattern, ..self }
}

/// Add a rewrite around the pattern
pub fn wrap_with_rewrite(self, replacement: DynamicPattern<MarzanoQueryContext>) -> Self {
let pattern = Pattern::Rewrite(Box::new(Rewrite::new(self.pattern, replacement, None)));
Self { pattern, ..self }
}

/// Wrap with accumulate
pub fn wrap_with_accumulate(self, other: Pattern<MarzanoQueryContext>) -> Self {
let pattern = Pattern::Accumulate(Box::new(Accumulate::new(self.pattern, other, None)));
Self { pattern, ..self }
}

/// Restrict the pattern
pub fn matches(self, other: Pattern<MarzanoQueryContext>) -> Self {
let joined = Pattern::And(Box::new(And::new(vec![self.pattern, other])));
Expand Down
7 changes: 3 additions & 4 deletions crates/core/src/test_callback.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use grit_pattern_matcher::pattern::{ResolvedPattern};
use grit_pattern_matcher::{constants::DEFAULT_FILE_NAME};
use grit_pattern_matcher::constants::DEFAULT_FILE_NAME;
use grit_pattern_matcher::pattern::ResolvedPattern;
use marzano_language::{grit_parser::MarzanoGritParser, target_language::TargetLanguage};
use std::{
path::Path,
sync::{atomic::AtomicBool, Arc},
};


use crate::{
pattern_compiler::{CompilationResult, PatternBuilder},
test_utils::{run_on_test_files, SyntheticFile},
Expand All @@ -27,7 +26,7 @@ fn test_callback() {
assert!(!callback_called.load(std::sync::atomic::Ordering::SeqCst));

let mut builder = PatternBuilder::start_empty(src, lang).unwrap();
builder = builder.matches_callback(Box::new(move |binding, context, state, _, _| {
builder = builder.matches_callback(Box::new(move |binding, context, state, _logs| {
let text = binding
.text(&state.files, context.language)
.unwrap()
Expand Down
12 changes: 6 additions & 6 deletions crates/grit-pattern-matcher/src/pattern/accumulate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ impl<Q: QueryContext> Matcher<Q> for Accumulate<Q> {
return Ok(false);
};
let resolved = context_node;
let Some(dynamic_right) = &self.dynamic_right else {
return Err(GritPatternError::new(
"Insert right hand side must be a code snippet when LHS is not a variable",
));
};
let left = PatternOrResolved::Resolved(resolved);
let right = ResolvedPattern::from_dynamic_pattern(dynamic_right, state, context, logs)?;

let right = if let Some(dynamic_right) = &self.dynamic_right {
ResolvedPattern::from_dynamic_pattern(dynamic_right, state, context, logs)?
} else {
ResolvedPattern::from_pattern(&self.right, state, context, logs)?
};

insert_effect(&left, right, state, context)
}
Expand Down
6 changes: 6 additions & 0 deletions crates/grit-pattern-matcher/src/pattern/dynamic_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ impl<Q: QueryContext> DynamicPattern<Q> {
let resolved = Q::ResolvedPattern::from_dynamic_pattern(self, state, context, logs)?;
Ok(resolved.text(&state.files, context.language())?.to_string())
}

/// Create a constant DynamicPattern from a string.
pub fn from_str_constant(s: &str) -> GritResult<Self> {
let parts = vec![DynamicSnippetPart::String(s.to_string())];
Ok(DynamicPattern::Snippet(DynamicSnippet { parts }))
}
}

impl<Q: QueryContext> PatternName for DynamicPattern<Q> {
Expand Down

0 comments on commit 44e1536

Please sign in to comment.