diff --git a/crates/core/src/built_in_functions.rs b/crates/core/src/built_in_functions.rs index 454e35a81..9bbcf2512 100644 --- a/crates/core/src/built_in_functions.rs +++ b/crates/core/src/built_in_functions.rs @@ -1,6 +1,6 @@ use crate::{ - lazy::LazyTraversal, marzano_context::MarzanoContext, - marzano_resolved_pattern::MarzanoResolvedPattern, paths::resolve, problem::MarzanoQueryContext, + marzano_context::MarzanoContext, marzano_resolved_pattern::MarzanoResolvedPattern, + paths::resolve, problem::MarzanoQueryContext, }; use anyhow::{anyhow, bail, Result}; use grit_pattern_matcher::{ @@ -39,8 +39,7 @@ pub type CallbackFn = dyn for<'a, 'b> Fn( &'b ::ResolvedPattern<'a>, &'a MarzanoContext<'a>, &mut State<'a, MarzanoQueryContext>, - &mut AnalysisLogs, - &mut LazyTraversal<'a, 'b> + &mut AnalysisLogs ) -> Result + Send + Sync; @@ -126,8 +125,7 @@ impl BuiltIns { state: &mut State<'a, MarzanoQueryContext>, logs: &mut AnalysisLogs, ) -> Result { - let mut lazy = LazyTraversal::new(binding); - (self.callbacks[call.callback_index])(binding, context, state, logs, &mut lazy) + (self.callbacks[call.callback_index])(binding, context, state, logs) } /// Add an anonymous built-in, used for callbacks diff --git a/crates/core/src/lazy.rs b/crates/core/src/lazy.rs index fb9f22f11..bbd2050d3 100644 --- a/crates/core/src/lazy.rs +++ b/crates/core/src/lazy.rs @@ -1,40 +1,3 @@ -use grit_pattern_matcher::pattern::Matcher; -use grit_pattern_matcher::pattern::Pattern; -use grit_pattern_matcher::pattern::State; -use grit_util::error::GritResult; -use grit_util::AnalysisLogs; - -use crate::marzano_context::MarzanoContext; -use crate::{marzano_resolved_pattern::MarzanoResolvedPattern, problem::MarzanoQueryContext}; - -#[derive(Debug, Clone)] -pub struct LazyTraversal<'a, 'b> { - root: &'b MarzanoResolvedPattern<'a>, -} - -impl<'a, 'b> LazyTraversal<'a, 'b> { - pub(crate) fn new(root: &'b MarzanoResolvedPattern<'a>) -> Self { - Self { root } - } - - #[allow(dead_code)] - pub(crate) fn matches( - &mut self, - pattern: Pattern, - context: &'a MarzanoContext<'a>, - state: &mut State<'a, MarzanoQueryContext>, - logs: &mut AnalysisLogs, - ) -> GritResult { - // THIS IS UNSAFE - // TODO: make this safe / improve the lifetimes so pattern does not need to be static - let borrowed_pattern: &'static Pattern = - unsafe { std::mem::transmute(&pattern) }; - - let matches = borrowed_pattern.execute(self.root, state, context, logs)?; - Ok(matches) - } -} - #[cfg(test)] mod test { @@ -75,40 +38,37 @@ mod test { assert!(!callback_called.load(std::sync::atomic::Ordering::SeqCst)); let mut builder = PatternBuilder::start_empty(src, lang).unwrap(); - builder = - builder.matches_callback(Box::new(move |_binding, context, state, logs, lazy| { - assert!(state.find_var_in_scope("$foo").is_some()); - assert!(state.find_var_in_scope("$bar").is_some()); - assert!(state.find_var_in_scope("$dude").is_none()); - assert!(state.find_var_in_scope("$baz").is_none()); - let _registered_var = state.register_var("fuzz"); - assert!(state.find_var_in_scope("fuzz").is_some()); - - let pattern = Pattern::Contains(Box::new(Contains::new( - Pattern::::StringConstant(StringConstant::new( - "name".to_owned(), - )), - None, - ))); - assert!(lazy.matches(pattern, context, state, logs).unwrap()); - - let non_matching_pattern = Pattern::Contains(Box::new(Contains::new( - Pattern::::StringConstant(StringConstant::new( - "not_found".to_owned(), - )), - None, - ))); - assert!(!lazy - .matches(non_matching_pattern, context, state, logs) - .unwrap()); - - callback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst); - Ok(true) - })); + builder = builder.matches_callback(Box::new(move |binding, context, state, logs| { + assert!(state.find_var_in_scope("$foo").is_some()); + assert!(state.find_var_in_scope("$bar").is_some()); + assert!(state.find_var_in_scope("$dude").is_none()); + assert!(state.find_var_in_scope("$baz").is_none()); + let _registered_var = state.register_var("fuzz"); + assert!(state.find_var_in_scope("fuzz").is_some()); + + let pattern = Pattern::Contains(Box::new(Contains::new( + Pattern::::StringConstant(StringConstant::new( + "name".to_owned(), + )), + None, + ))); + assert!(binding.matches(&pattern, state, context, logs).unwrap()); + + let non_matching_pattern = Pattern::Contains(Box::new(Contains::new( + Pattern::::StringConstant(StringConstant::new( + "not_found".to_owned(), + )), + None, + ))); + assert!(!binding + .matches(&non_matching_pattern, state, context, logs) + .unwrap()); + + callback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst); + Ok(true) + })); let CompilationResult { problem, .. } = builder.compile(None, None, true).unwrap(); - println!("problem: {:?}", problem); - let test_files = vec![SyntheticFile::new( "file.js".to_owned(), r#"function myLogger() { diff --git a/crates/core/src/marzano_resolved_pattern.rs b/crates/core/src/marzano_resolved_pattern.rs index 6b8ec99ab..a77e6760e 100644 --- a/crates/core/src/marzano_resolved_pattern.rs +++ b/crates/core/src/marzano_resolved_pattern.rs @@ -2,6 +2,7 @@ use crate::{ marzano_binding::MarzanoBinding, marzano_code_snippet::MarzanoCodeSnippet, marzano_context::MarzanoContext, paths::absolutize, problem::MarzanoQueryContext, }; +use grit_pattern_matcher::pattern::Matcher; use grit_pattern_matcher::{ binding::Binding, constant::Constant, @@ -46,6 +47,24 @@ impl<'a> MarzanoResolvedPattern<'a> { Self::from_binding(MarzanoBinding::List(node, field_id)) } + /// Check if a pattern matches a provided pattern + /// + /// Note this leaks memory, so should only be used in short-lived programs + #[allow(dead_code)] + pub(crate) fn matches( + &self, + pattern: &Pattern, + state: &mut State<'a, MarzanoQueryContext>, + context: &'a MarzanoContext<'a>, + logs: &mut AnalysisLogs, + ) -> GritResult { + let borrowed_pattern: &'static Pattern = + Box::leak(Box::new(pattern.clone())); + + let matches = borrowed_pattern.execute(self, state, context, logs)?; + Ok(matches) + } + fn to_snippets(&self) -> GritResult>> { match self { MarzanoResolvedPattern::Snippets(snippets) => Ok(snippets.clone()), @@ -486,12 +505,10 @@ impl<'a> ResolvedPattern<'a, MarzanoQueryContext> for MarzanoResolvedPattern<'a> Pattern::CallBuiltIn(built_in) => built_in.call(state, context, logs), Pattern::CallFunction(func) => func.call(state, context, logs), Pattern::CallForeignFunction(func) => func.call(state, context, logs), - Pattern::CallbackPattern(callback) => { - Err(GritPatternError::new(format!( - "cannot make resolved pattern from callback pattern {}", - callback.name() - ))) - } + Pattern::CallbackPattern(callback) => Err(GritPatternError::new(format!( + "cannot make resolved pattern from callback pattern {}", + callback.name() + ))), Pattern::StringConstant(string) => Ok(Self::Snippets(vector![ResolvedSnippet::Text( (&string.text).into(), )])), diff --git a/crates/core/src/pattern_compiler/builder.rs b/crates/core/src/pattern_compiler/builder.rs index 1e38a8178..f8c88ba53 100644 --- a/crates/core/src/pattern_compiler/builder.rs +++ b/crates/core/src/pattern_compiler/builder.rs @@ -20,8 +20,8 @@ use grit_pattern_matcher::{ ABSOLUTE_PATH_INDEX, DEFAULT_FILE_NAME, FILENAME_INDEX, NEW_FILES_INDEX, PROGRAM_INDEX, }, pattern::{ - And, GritFunctionDefinition, Pattern, PatternDefinition, Predicate, PredicateDefinition, - VariableSourceLocations, Where, + Accumulate, And, DynamicPattern, GritFunctionDefinition, Pattern, PatternDefinition, + Predicate, PredicateDefinition, Rewrite, VariableSourceLocations, Where, }, }; use grit_util::{AnalysisLogs, Ast, FileRange}; @@ -235,6 +235,18 @@ impl PatternBuilder { Self { pattern, ..self } } + /// Add a rewrite around the pattern + pub fn wrap_with_rewrite(self, replacement: DynamicPattern) -> Self { + let pattern = Pattern::Rewrite(Box::new(Rewrite::new(self.pattern, replacement, None))); + Self { pattern, ..self } + } + + /// Wrap with accumulate + pub fn wrap_with_accumulate(self, other: Pattern) -> Self { + let pattern = Pattern::Accumulate(Box::new(Accumulate::new(self.pattern, other, None))); + Self { pattern, ..self } + } + /// Restrict the pattern pub fn matches(self, other: Pattern) -> Self { let joined = Pattern::And(Box::new(And::new(vec![self.pattern, other]))); diff --git a/crates/core/src/test_callback.rs b/crates/core/src/test_callback.rs index 60714a8ac..bae3cf0be 100644 --- a/crates/core/src/test_callback.rs +++ b/crates/core/src/test_callback.rs @@ -1,12 +1,11 @@ -use grit_pattern_matcher::pattern::{ResolvedPattern}; -use grit_pattern_matcher::{constants::DEFAULT_FILE_NAME}; +use grit_pattern_matcher::constants::DEFAULT_FILE_NAME; +use grit_pattern_matcher::pattern::ResolvedPattern; use marzano_language::{grit_parser::MarzanoGritParser, target_language::TargetLanguage}; use std::{ path::Path, sync::{atomic::AtomicBool, Arc}, }; - use crate::{ pattern_compiler::{CompilationResult, PatternBuilder}, test_utils::{run_on_test_files, SyntheticFile}, @@ -27,7 +26,7 @@ fn test_callback() { assert!(!callback_called.load(std::sync::atomic::Ordering::SeqCst)); let mut builder = PatternBuilder::start_empty(src, lang).unwrap(); - builder = builder.matches_callback(Box::new(move |binding, context, state, _, _| { + builder = builder.matches_callback(Box::new(move |binding, context, state, _logs| { let text = binding .text(&state.files, context.language) .unwrap() diff --git a/crates/grit-pattern-matcher/src/pattern/accumulate.rs b/crates/grit-pattern-matcher/src/pattern/accumulate.rs index 590707852..b22046fca 100644 --- a/crates/grit-pattern-matcher/src/pattern/accumulate.rs +++ b/crates/grit-pattern-matcher/src/pattern/accumulate.rs @@ -58,13 +58,13 @@ impl Matcher for Accumulate { return Ok(false); }; let resolved = context_node; - let Some(dynamic_right) = &self.dynamic_right else { - return Err(GritPatternError::new( - "Insert right hand side must be a code snippet when LHS is not a variable", - )); - }; let left = PatternOrResolved::Resolved(resolved); - let right = ResolvedPattern::from_dynamic_pattern(dynamic_right, state, context, logs)?; + + let right = if let Some(dynamic_right) = &self.dynamic_right { + ResolvedPattern::from_dynamic_pattern(dynamic_right, state, context, logs)? + } else { + ResolvedPattern::from_pattern(&self.right, state, context, logs)? + }; insert_effect(&left, right, state, context) } diff --git a/crates/grit-pattern-matcher/src/pattern/dynamic_snippet.rs b/crates/grit-pattern-matcher/src/pattern/dynamic_snippet.rs index 0b4f7e168..5844e2183 100644 --- a/crates/grit-pattern-matcher/src/pattern/dynamic_snippet.rs +++ b/crates/grit-pattern-matcher/src/pattern/dynamic_snippet.rs @@ -49,6 +49,12 @@ impl DynamicPattern { let resolved = Q::ResolvedPattern::from_dynamic_pattern(self, state, context, logs)?; Ok(resolved.text(&state.files, context.language())?.to_string()) } + + /// Create a constant DynamicPattern from a string. + pub fn from_str_constant(s: &str) -> GritResult { + let parts = vec![DynamicSnippetPart::String(s.to_string())]; + Ok(DynamicPattern::Snippet(DynamicSnippet { parts })) + } } impl PatternName for DynamicPattern {