Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: simplify the lazy approach #483

Merged
merged 5 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions crates/core/src/built_in_functions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
lazy::LazyTraversal, marzano_context::MarzanoContext,
marzano_resolved_pattern::MarzanoResolvedPattern, paths::resolve, problem::MarzanoQueryContext,
marzano_context::MarzanoContext, marzano_resolved_pattern::MarzanoResolvedPattern,
paths::resolve, problem::MarzanoQueryContext,
};
use anyhow::{anyhow, bail, Result};
use grit_pattern_matcher::{
Expand Down Expand Up @@ -39,8 +39,7 @@ pub type CallbackFn = dyn for<'a, 'b> Fn(
&'b <crate::problem::MarzanoQueryContext as grit_pattern_matcher::context::QueryContext>::ResolvedPattern<'a>,
&'a MarzanoContext<'a>,
&mut State<'a, MarzanoQueryContext>,
&mut AnalysisLogs,
&mut LazyTraversal<'a, 'b>
&mut AnalysisLogs
) -> Result<bool>
+ Send
+ Sync;
Expand Down Expand Up @@ -126,8 +125,7 @@ impl BuiltIns {
state: &mut State<'a, MarzanoQueryContext>,
logs: &mut AnalysisLogs,
) -> Result<bool> {
let mut lazy = LazyTraversal::new(binding);
(self.callbacks[call.callback_index])(binding, context, state, logs, &mut lazy)
(self.callbacks[call.callback_index])(binding, context, state, logs)
}

/// Add an anonymous built-in, used for callbacks
Expand Down
98 changes: 29 additions & 69 deletions crates/core/src/lazy.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,3 @@
use grit_pattern_matcher::pattern::Matcher;
use grit_pattern_matcher::pattern::Pattern;
use grit_pattern_matcher::pattern::State;
use grit_util::error::GritResult;
use grit_util::AnalysisLogs;

use crate::marzano_context::MarzanoContext;
use crate::{marzano_resolved_pattern::MarzanoResolvedPattern, problem::MarzanoQueryContext};

#[derive(Debug, Clone)]
pub struct LazyTraversal<'a, 'b> {
root: &'b MarzanoResolvedPattern<'a>,
}

impl<'a, 'b> LazyTraversal<'a, 'b> {
pub(crate) fn new(root: &'b MarzanoResolvedPattern<'a>) -> Self {
Self { root }
}

#[allow(dead_code)]
pub(crate) fn matches(
&mut self,
pattern: Pattern<MarzanoQueryContext>,
context: &'a MarzanoContext<'a>,
state: &mut State<'a, MarzanoQueryContext>,
logs: &mut AnalysisLogs,
) -> GritResult<bool> {
// THIS IS UNSAFE
// TODO: make this safe / improve the lifetimes so pattern does not need to be static
let borrowed_pattern: &'static Pattern<MarzanoQueryContext> =
unsafe { std::mem::transmute(&pattern) };

let matches = borrowed_pattern.execute(self.root, state, context, logs)?;
Ok(matches)
}
}

#[cfg(test)]
mod test {

Expand Down Expand Up @@ -75,40 +38,37 @@ mod test {
assert!(!callback_called.load(std::sync::atomic::Ordering::SeqCst));

let mut builder = PatternBuilder::start_empty(src, lang).unwrap();
builder =
builder.matches_callback(Box::new(move |_binding, context, state, logs, lazy| {
assert!(state.find_var_in_scope("$foo").is_some());
assert!(state.find_var_in_scope("$bar").is_some());
assert!(state.find_var_in_scope("$dude").is_none());
assert!(state.find_var_in_scope("$baz").is_none());
let _registered_var = state.register_var("fuzz");
assert!(state.find_var_in_scope("fuzz").is_some());

let pattern = Pattern::Contains(Box::new(Contains::new(
Pattern::<MarzanoQueryContext>::StringConstant(StringConstant::new(
"name".to_owned(),
)),
None,
)));
assert!(lazy.matches(pattern, context, state, logs).unwrap());

let non_matching_pattern = Pattern::Contains(Box::new(Contains::new(
Pattern::<MarzanoQueryContext>::StringConstant(StringConstant::new(
"not_found".to_owned(),
)),
None,
)));
assert!(!lazy
.matches(non_matching_pattern, context, state, logs)
.unwrap());

callback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(true)
}));
builder = builder.matches_callback(Box::new(move |binding, context, state, logs| {
assert!(state.find_var_in_scope("$foo").is_some());
assert!(state.find_var_in_scope("$bar").is_some());
assert!(state.find_var_in_scope("$dude").is_none());
assert!(state.find_var_in_scope("$baz").is_none());
let _registered_var = state.register_var("fuzz");
assert!(state.find_var_in_scope("fuzz").is_some());

let pattern = Pattern::Contains(Box::new(Contains::new(
Pattern::<MarzanoQueryContext>::StringConstant(StringConstant::new(
"name".to_owned(),
)),
None,
)));
assert!(binding.matches(&pattern, state, context, logs).unwrap());

let non_matching_pattern = Pattern::Contains(Box::new(Contains::new(
Pattern::<MarzanoQueryContext>::StringConstant(StringConstant::new(
"not_found".to_owned(),
)),
None,
)));
assert!(!binding
.matches(&non_matching_pattern, state, context, logs)
.unwrap());

callback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(true)
}));
let CompilationResult { problem, .. } = builder.compile(None, None, true).unwrap();

println!("problem: {:?}", problem);

let test_files = vec![SyntheticFile::new(
"file.js".to_owned(),
r#"function myLogger() {
Expand Down
29 changes: 23 additions & 6 deletions crates/core/src/marzano_resolved_pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
marzano_binding::MarzanoBinding, marzano_code_snippet::MarzanoCodeSnippet,
marzano_context::MarzanoContext, paths::absolutize, problem::MarzanoQueryContext,
};
use grit_pattern_matcher::pattern::Matcher;
use grit_pattern_matcher::{
binding::Binding,
constant::Constant,
Expand Down Expand Up @@ -46,6 +47,24 @@ impl<'a> MarzanoResolvedPattern<'a> {
Self::from_binding(MarzanoBinding::List(node, field_id))
}

/// Check if a pattern matches a provided pattern
///
/// Note this leaks memory, so should only be used in short-lived programs
#[allow(dead_code)]
pub(crate) fn matches(
&self,
pattern: &Pattern<MarzanoQueryContext>,
state: &mut State<'a, MarzanoQueryContext>,
context: &'a MarzanoContext<'a>,
logs: &mut AnalysisLogs,
) -> GritResult<bool> {
let borrowed_pattern: &'static Pattern<MarzanoQueryContext> =
Box::leak(Box::new(pattern.clone()));

let matches = borrowed_pattern.execute(self, state, context, logs)?;
Ok(matches)
}

fn to_snippets(&self) -> GritResult<Vector<ResolvedSnippet<'a, MarzanoQueryContext>>> {
match self {
MarzanoResolvedPattern::Snippets(snippets) => Ok(snippets.clone()),
Expand Down Expand Up @@ -486,12 +505,10 @@ impl<'a> ResolvedPattern<'a, MarzanoQueryContext> for MarzanoResolvedPattern<'a>
Pattern::CallBuiltIn(built_in) => built_in.call(state, context, logs),
Pattern::CallFunction(func) => func.call(state, context, logs),
Pattern::CallForeignFunction(func) => func.call(state, context, logs),
Pattern::CallbackPattern(callback) => {
Err(GritPatternError::new(format!(
"cannot make resolved pattern from callback pattern {}",
callback.name()
)))
}
Pattern::CallbackPattern(callback) => Err(GritPatternError::new(format!(
"cannot make resolved pattern from callback pattern {}",
callback.name()
))),
Pattern::StringConstant(string) => Ok(Self::Snippets(vector![ResolvedSnippet::Text(
(&string.text).into(),
)])),
Expand Down
16 changes: 14 additions & 2 deletions crates/core/src/pattern_compiler/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use grit_pattern_matcher::{
ABSOLUTE_PATH_INDEX, DEFAULT_FILE_NAME, FILENAME_INDEX, NEW_FILES_INDEX, PROGRAM_INDEX,
},
pattern::{
And, GritFunctionDefinition, Pattern, PatternDefinition, Predicate, PredicateDefinition,
VariableSourceLocations, Where,
Accumulate, And, DynamicPattern, GritFunctionDefinition, Pattern, PatternDefinition,
Predicate, PredicateDefinition, Rewrite, VariableSourceLocations, Where,
},
};
use grit_util::{AnalysisLogs, Ast, FileRange};
Expand Down Expand Up @@ -235,6 +235,18 @@ impl PatternBuilder {
Self { pattern, ..self }
}

/// Add a rewrite around the pattern
pub fn wrap_with_rewrite(self, replacement: DynamicPattern<MarzanoQueryContext>) -> Self {
let pattern = Pattern::Rewrite(Box::new(Rewrite::new(self.pattern, replacement, None)));
Self { pattern, ..self }
}

/// Wrap with accumulate
pub fn wrap_with_accumulate(self, other: Pattern<MarzanoQueryContext>) -> Self {
let pattern = Pattern::Accumulate(Box::new(Accumulate::new(self.pattern, other, None)));
Self { pattern, ..self }
}

/// Restrict the pattern
pub fn matches(self, other: Pattern<MarzanoQueryContext>) -> Self {
let joined = Pattern::And(Box::new(And::new(vec![self.pattern, other])));
Expand Down
7 changes: 3 additions & 4 deletions crates/core/src/test_callback.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use grit_pattern_matcher::pattern::{ResolvedPattern};
use grit_pattern_matcher::{constants::DEFAULT_FILE_NAME};
use grit_pattern_matcher::constants::DEFAULT_FILE_NAME;
use grit_pattern_matcher::pattern::ResolvedPattern;
use marzano_language::{grit_parser::MarzanoGritParser, target_language::TargetLanguage};
use std::{
path::Path,
sync::{atomic::AtomicBool, Arc},
};


use crate::{
pattern_compiler::{CompilationResult, PatternBuilder},
test_utils::{run_on_test_files, SyntheticFile},
Expand All @@ -27,7 +26,7 @@ fn test_callback() {
assert!(!callback_called.load(std::sync::atomic::Ordering::SeqCst));

let mut builder = PatternBuilder::start_empty(src, lang).unwrap();
builder = builder.matches_callback(Box::new(move |binding, context, state, _, _| {
builder = builder.matches_callback(Box::new(move |binding, context, state, _logs| {
let text = binding
.text(&state.files, context.language)
.unwrap()
Expand Down
12 changes: 6 additions & 6 deletions crates/grit-pattern-matcher/src/pattern/accumulate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ impl<Q: QueryContext> Matcher<Q> for Accumulate<Q> {
return Ok(false);
};
let resolved = context_node;
let Some(dynamic_right) = &self.dynamic_right else {
return Err(GritPatternError::new(
"Insert right hand side must be a code snippet when LHS is not a variable",
));
};
let left = PatternOrResolved::Resolved(resolved);
let right = ResolvedPattern::from_dynamic_pattern(dynamic_right, state, context, logs)?;

let right = if let Some(dynamic_right) = &self.dynamic_right {
ResolvedPattern::from_dynamic_pattern(dynamic_right, state, context, logs)?
} else {
ResolvedPattern::from_pattern(&self.right, state, context, logs)?
};

insert_effect(&left, right, state, context)
}
Expand Down
6 changes: 6 additions & 0 deletions crates/grit-pattern-matcher/src/pattern/dynamic_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ impl<Q: QueryContext> DynamicPattern<Q> {
let resolved = Q::ResolvedPattern::from_dynamic_pattern(self, state, context, logs)?;
Ok(resolved.text(&state.files, context.language())?.to_string())
}

/// Create a constant DynamicPattern from a string.
pub fn from_str_constant(s: &str) -> GritResult<Self> {
let parts = vec![DynamicSnippetPart::String(s.to_string())];
Ok(DynamicPattern::Snippet(DynamicSnippet { parts }))
}
}

impl<Q: QueryContext> PatternName for DynamicPattern<Q> {
Expand Down
Loading