From 37ee41a80f5cf9760a411e8282617bfab07e2c10 Mon Sep 17 00:00:00 2001 From: goose Date: Tue, 22 Oct 2024 17:00:26 +0700 Subject: [PATCH] feat(core): deduplicate log targets Previously, one log targets can belong to many log containers. Consider this in Javasript: ``` const foo = { test: () => { const bar = "bar" } } ``` If you log the variable `bar`, it will result in two lines of logs, one for the inner function container, one for the outer declaration container. This PR fixes that by picking the deepest container when there are more than one. We assume that users want the smallest scope possible. --- lua/neolog/actions.lua | 89 ++------ lua/neolog/actions/treesitter.lua | 218 +++++++++++++++++++ lua/neolog/treesitter.lua | 97 --------- lua/neolog/utils.lua | 35 +++ tests/neolog/actions/neolog_actions_spec.lua | 54 +++++ 5 files changed, 324 insertions(+), 169 deletions(-) create mode 100644 lua/neolog/actions/treesitter.lua delete mode 100644 lua/neolog/treesitter.lua diff --git a/lua/neolog/actions.lua b/lua/neolog/actions.lua index ea6f4e1..98227e6 100644 --- a/lua/neolog/actions.lua +++ b/lua/neolog/actions.lua @@ -5,6 +5,7 @@ local M = { log_templates = {}, batch_log_templates = {}, batch = {} } local highlight = require("neolog.highlight") +local treesitter = require("neolog.actions.treesitter") local utils = require("neolog.utils") ---@param line_number number 1-indexed @@ -108,71 +109,6 @@ local function after_insert_log_statements(statements) end end ----Query all target containers in the current buffer that intersect with the given range ----@alias logable_range {[1]: number, [2]: number} ----@param lang string ----@param range {[1]: number, [2]: number, [3]: number, [4]: number} ----@return {container: TSNode, logable_range: logable_range?}[] -local function query_log_target_container(lang, range) - local bufnr = vim.api.nvim_get_current_buf() - local parser = vim.treesitter.get_parser(bufnr, lang) - local tree = parser:parse()[1] - local root = tree:root() - - local query = vim.treesitter.query.get(lang, "neolog-log-container") - if not query then - vim.notify(string.format("logging_framework doesn't support %s language", lang), vim.log.levels.ERROR) - return {} - end - - local containers = {} - - for _, match, metadata in query:iter_matches(root, bufnr, 0, -1) do - ---@type TSNode - local log_container = match[utils.get_key_by_value(query.captures, "log_container")] - - if log_container and utils.ranges_intersect(utils.get_ts_node_range(log_container), range) then - ---@type TSNode? - local logable_range = match[utils.get_key_by_value(query.captures, "logable_range")] - - local logable_range_col_range - - if metadata.adjusted_logable_range then - logable_range_col_range = { - metadata.adjusted_logable_range[1], - metadata.adjusted_logable_range[3], - } - elseif logable_range then - logable_range_col_range = { logable_range:start(), logable_range:end_() } - end - - table.insert(containers, { container = log_container, logable_range = logable_range_col_range }) - end - end - - return containers -end - ----Find all the log target nodes in the given container ----@param container TSNode ----@param lang string ----@return TSNode[] -local function find_log_target(container, lang) - local query = vim.treesitter.query.get(lang, "neolog-log-target") - if not query then - vim.notify(string.format("logging_framework doesn't support %s language", lang), vim.log.levels.ERROR) - return {} - end - - local bufnr = vim.api.nvim_get_current_buf() - local log_targets = {} - for _, node in query:iter_captures(container, bufnr, 0, -1) do - table.insert(log_targets, node) - end - - return log_targets -end - ---@param filetype string ---@return string? local function get_lang(filetype) @@ -206,7 +142,7 @@ local function group_overlapping_log_targets(log_targets) table.insert(current_group, log_target) else -- Check the current node with each node in the current group - -- If it matches any of the node, it belongs to the current group + -- If it intersects with any of the node, it belongs to the current group -- If it not, move it into a new group local insersect_any = utils.array_any(current_group, function(node) return utils.ranges_intersect(utils.get_ts_node_range(node), utils.get_ts_node_range(log_target)) @@ -267,15 +203,20 @@ end ---@return {log_container: TSNode, logable_range: logable_range?, log_targets: TSNode[]}[] local function capture_log_targets(lang) local selection_range = utils.get_selection_range() - local log_containers = query_log_target_container(lang, selection_range) + local log_containers = treesitter.query_log_target_container(lang, selection_range) local result = {} - for _, log_container in ipairs(log_containers) do - local log_targets = find_log_target(log_container.container, lang) + local log_target_grouped_by_container = treesitter.find_log_targets( + utils.array_map(log_containers, function(i) + return i.container + end), + lang + ) + for _, entry in ipairs(log_target_grouped_by_container) do -- Filter targets that intersect with the given range - log_targets = utils.array_filter(log_targets, function(node) + local log_targets = utils.array_filter(entry.log_targets, function(node) return utils.ranges_intersect(selection_range, utils.get_ts_node_range(node)) end) @@ -286,6 +227,11 @@ local function capture_log_targets(lang) return pick_best_node(group, selection_range) end) + local log_container = utils.array_find(log_containers, function(i) + return i.container == entry.container + end) + ---@cast log_container -nil + table.insert(result, { log_container = log_container.container, logable_range = log_container.logable_range, @@ -539,8 +485,7 @@ function M.setup(templates, batch_templates) M.log_templates = templates M.batch_log_templates = batch_templates - -- Register the custom directive - require("neolog.treesitter") + treesitter.setup() end return M diff --git a/lua/neolog/actions/treesitter.lua b/lua/neolog/actions/treesitter.lua new file mode 100644 index 0000000..371a77e --- /dev/null +++ b/lua/neolog/actions/treesitter.lua @@ -0,0 +1,218 @@ +local M = {} + +local utils = require("neolog.utils") + +---Sort the given nodes in the order that they would appear in a preorder traversal +local function sort_ts_nodes_preorder(nodes) + return utils.array_sort_with_index(nodes, function(a, b) + local result = utils.compare_ts_node_start(a[1], b[1]) + if result == "equal" then + result = utils.compare_ts_node_end(a[1], b[1]) + + -- It the containers have exactly the same range, sort by the appearance order + return result == "equal" and a[2] < b[2] or result == "after" + else + return result == "before" + end + end) +end + +---Query all target containers in the current buffer that intersect with the given range +---It's possible to have containers which contain one another. They form a subtree. +---In this case, we pick the deepest child in the subtree. +---@alias logable_range {[1]: number, [2]: number} +---@param lang string +---@param range {[1]: number, [2]: number, [3]: number, [4]: number} +---@return {container: TSNode, logable_range: logable_range?}[] +function M.query_log_target_container(lang, range) + local bufnr = vim.api.nvim_get_current_buf() + local parser = vim.treesitter.get_parser(bufnr, lang) + local tree = parser:parse()[1] + local root = tree:root() + + local query = vim.treesitter.query.get(lang, "neolog-log-container") + if not query then + vim.notify(string.format("neolog doesn't support %s language", lang), vim.log.levels.ERROR) + return {} + end + + local containers = {} + + for _, match, metadata in query:iter_matches(root, bufnr, 0, -1) do + ---@type TSNode + local log_container = match[utils.get_key_by_value(query.captures, "log_container")] + + if log_container and utils.ranges_intersect(utils.get_ts_node_range(log_container), range) then + ---@type TSNode? + local logable_range = match[utils.get_key_by_value(query.captures, "logable_range")] + + local logable_range_col_range + + if metadata.adjusted_logable_range then + logable_range_col_range = { + metadata.adjusted_logable_range[1], + metadata.adjusted_logable_range[3], + } + elseif logable_range then + logable_range_col_range = { logable_range:start(), logable_range:end_() } + end + + table.insert(containers, { container = log_container, logable_range = logable_range_col_range }) + end + end + + return containers +end + +---Find all the log target nodes in the given containers +---A log target can belong to multiple containers. In this case, we pick the deepest container +---@param containers TSNode[] +---@param lang string +---@return {container: TSNode, log_targets: TSNode[]}[] +function M.find_log_targets(containers, lang) + local query = vim.treesitter.query.get(lang, "neolog-log-target") + if not query then + vim.notify(string.format("neolog doesn't support %s language", lang), vim.log.levels.ERROR) + return {} + end + + local bufnr = vim.api.nvim_get_current_buf() + local entries = {} + + ---@type { [string]: TSNode } + local log_targets_table = {} + + for _, container in ipairs(containers) do + for _, node in query:iter_captures(container, bufnr, 0, -1) do + table.insert(entries, { log_container = container, log_target = node }) + log_targets_table[node:id()] = node + end + end + + -- Group by log target + local grouped_log_targets = utils.array_group_by(entries, function(i) + return i.log_target:id() + end, function(i) + return i.log_container + end) + + local grouped_log_containers = {} + + -- If there's multiple containers for the same log target, pick the deepest container + for log_target_id, log_containers in pairs(grouped_log_targets) do + local sorted_group = sort_ts_nodes_preorder(log_containers) + local deepest_container = sorted_group[#sorted_group] + + local log_target = log_targets_table[log_target_id] + if grouped_log_containers[deepest_container] then + table.insert(grouped_log_containers[deepest_container].log_targets, log_target) + else + grouped_log_containers[deepest_container] = { container = deepest_container, log_targets = { log_target } } + end + end + + return utils.table_values(grouped_log_containers) +end + +---Check if the given node: +--- 1. Has a parent node of type `parent_type` +--- 2. Is a field `field_name` of the parent node +---@param node TSNode? +---@param parent_type string +---@param field_name string +---@return boolean +local function is_node_field_of_parent(node, parent_type, field_name) + if not node then + return false + end + + local parent = node:parent() + if not parent or parent:type() ~= parent_type then + return false + end + + local field_nodes = parent:field(field_name) + return vim.list_contains(field_nodes, node) +end + +---Check if the given node: +--- 1. Has an ancestor node of type `ancestor_type` +--- 2. Is in the subtree of field `field_name` of the ancestor node +---@param node TSNode? +---@param ancestor_type string +---@param field_name string +---@return boolean +local function is_node_field_of_ancestor(node, ancestor_type, field_name) + local current = node + + while current do + if is_node_field_of_parent(current, ancestor_type, field_name) then + return true + end + + current = current:parent() + end + + return false +end + +function M.setup() + -- Adjust the range of the node + vim.treesitter.query.add_directive("adjust-range!", function(match, _, _, predicate, metadata) + local capture_id = predicate[2] + + ---@type TSNode + local node = match[capture_id] + + -- Get the adjustment values from the predicate arguments + local start_adjust = tonumber(predicate[3]) or 0 + local end_adjust = tonumber(predicate[4]) or 0 + + -- Get the original range + local start_row, start_col, end_row, end_col = node:range() + + -- Adjust the range + local adjusted_start_row = math.max(0, start_row + start_adjust) -- Ensure we don't go below 0 + local adjusted_end_row = math.max(adjusted_start_row, end_row + end_adjust) -- Ensure end is not before start + + -- Store the adjusted range in metadata + metadata.adjusted_logable_range = { adjusted_start_row, start_col, adjusted_end_row, end_col } + end, { force = true }) + + -- Similar to has-parent?, but also check the node is a field of the parent + vim.treesitter.query.add_predicate("field-of-parent?", function(match, _, _, predicate) + local node = match[predicate[2]] + local parent_type = predicate[3] + local field_name = predicate[4] + + return is_node_field_of_parent(node, parent_type, field_name) + end, { force = true }) + + -- The negation of field-of-parent? + vim.treesitter.query.add_predicate("not-field-of-parent?", function(match, _, _, predicate) + local node = match[predicate[2]] + local parent_type = predicate[3] + local field_name = predicate[4] + + return not is_node_field_of_parent(node, parent_type, field_name) + end, { force = true }) + + -- Similar to has-ancestor?, but also check the node is in a field of the ancestor subtree + vim.treesitter.query.add_predicate("field-of-ancestor?", function(match, _, _, predicate) + local node = match[predicate[2]] + local ancestor_type = predicate[3] + local field_name = predicate[4] + + return is_node_field_of_ancestor(node, ancestor_type, field_name) + end, { force = true }) + + vim.treesitter.query.add_predicate("not-field-of-ancestor?", function(match, _, _, predicate) + local node = match[predicate[2]] + local ancestor_type = predicate[3] + local field_name = predicate[4] + + return not is_node_field_of_ancestor(node, ancestor_type, field_name) + end, { force = true }) +end + +return M diff --git a/lua/neolog/treesitter.lua b/lua/neolog/treesitter.lua deleted file mode 100644 index c7c0531..0000000 --- a/lua/neolog/treesitter.lua +++ /dev/null @@ -1,97 +0,0 @@ ----Check if the given node: ---- 1. Has a parent node of type `parent_type` ---- 2. Is a field `field_name` of the parent node ----@param node TSNode? ----@param parent_type string ----@param field_name string ----@return boolean -local function is_node_field_of_parent(node, parent_type, field_name) - if not node then - return false - end - - local parent = node:parent() - if not parent or parent:type() ~= parent_type then - return false - end - - local field_nodes = parent:field(field_name) - return vim.list_contains(field_nodes, node) -end - ----Check if the given node: ---- 1. Has an ancestor node of type `ancestor_type` ---- 2. Is in the subtree of field `field_name` of the ancestor node ----@param node TSNode? ----@param ancestor_type string ----@param field_name string ----@return boolean -local function is_node_field_of_ancestor(node, ancestor_type, field_name) - local current = node - - while current do - if is_node_field_of_parent(current, ancestor_type, field_name) then - return true - end - - current = current:parent() - end - - return false -end - -vim.treesitter.query.add_directive("adjust-range!", function(match, _, _, predicate, metadata) - local capture_id = predicate[2] - - ---@type TSNode - local node = match[capture_id] - - -- Get the adjustment values from the predicate arguments - local start_adjust = tonumber(predicate[3]) or 0 - local end_adjust = tonumber(predicate[4]) or 0 - - -- Get the original range - local start_row, start_col, end_row, end_col = node:range() - - -- Adjust the range - local adjusted_start_row = math.max(0, start_row + start_adjust) -- Ensure we don't go below 0 - local adjusted_end_row = math.max(adjusted_start_row, end_row + end_adjust) -- Ensure end is not before start - - -- Store the adjusted range in metadata - metadata.adjusted_logable_range = { adjusted_start_row, start_col, adjusted_end_row, end_col } -end, { force = true }) - --- Similar to has-parent?, but also check the node is a field of the parent -vim.treesitter.query.add_predicate("field-of-parent?", function(match, _, _, predicate) - local node = match[predicate[2]] - local parent_type = predicate[3] - local field_name = predicate[4] - - return is_node_field_of_parent(node, parent_type, field_name) -end, { force = true }) - --- The negation of field-of-parent? -vim.treesitter.query.add_predicate("not-field-of-parent?", function(match, _, _, predicate) - local node = match[predicate[2]] - local parent_type = predicate[3] - local field_name = predicate[4] - - return not is_node_field_of_parent(node, parent_type, field_name) -end, { force = true }) - --- Similar to has-ancestor?, but also check the node is in a field of the ancestor subtree -vim.treesitter.query.add_predicate("field-of-ancestor?", function(match, _, _, predicate) - local node = match[predicate[2]] - local ancestor_type = predicate[3] - local field_name = predicate[4] - - return is_node_field_of_ancestor(node, ancestor_type, field_name) -end, { force = true }) - -vim.treesitter.query.add_predicate("not-field-of-ancestor?", function(match, _, _, predicate) - local node = match[predicate[2]] - local ancestor_type = predicate[3] - local field_name = predicate[4] - - return not is_node_field_of_ancestor(node, ancestor_type, field_name) -end, { force = true }) diff --git a/lua/neolog/utils.lua b/lua/neolog/utils.lua index 534cd06..1e77953 100644 --- a/lua/neolog/utils.lua +++ b/lua/neolog/utils.lua @@ -42,6 +42,31 @@ function M.array_any(array, predicate) return false end +function M.array_group_by(array, key_function, value_function) + local result = {} + + key_function = key_function or function(v) + return v + end + + value_function = value_function or function(v) + return v + end + + for _, v in ipairs(array) do + local key = key_function(v) + local value = value_function(v) + + if result[key] then + table.insert(result[key], value) + else + result[key] = { value } + end + end + + return result +end + function M.array_sort_with_index(array, comparator) local with_index = M.array_map(array, function(v, i) return { v, i } @@ -54,6 +79,16 @@ function M.array_sort_with_index(array, comparator) end) end +function M.table_values(t) + local result = {} + + for _, v in pairs(t) do + table.insert(result, v) + end + + return result +end + function M.get_key_by_value(t, value) for k, v in pairs(t) do if v == value then diff --git a/tests/neolog/actions/neolog_actions_spec.lua b/tests/neolog/actions/neolog_actions_spec.lua index 1a0c3da..c7f110e 100644 --- a/tests/neolog/actions/neolog_actions_spec.lua +++ b/tests/neolog/actions/neolog_actions_spec.lua @@ -206,6 +206,60 @@ describe("neolog.actions.insert_log", function() end) end) + describe("a log target belongs to multiple log containers", function() + it("chooses the deepest container", function() + neolog.setup({ + log_templates = { + testing = { + javascript = [[console.log("Testing", %identifier)]], + }, + }, + }) + + local input = [[ + const foo = { + bar: () => { + const ba|z = 123 + }, + }; + ]] + + helper.assert_scenario({ + input = input, + filetype = "javascript", + action = function() + actions.insert_log({ template = "testing", position = "below" }) + end, + expected = [[ + const foo = { + bar: () => { + const baz = 123 + console.log("Testing", baz) + }, + }; + ]], + }) + + helper.assert_scenario({ + input = input, + filetype = "javascript", + action = function() + vim.cmd("normal! Vap") + actions.insert_log({ template = "testing", position = "below" }) + end, + expected = [[ + const foo = { + bar: () => { + const baz = 123 + console.log("Testing", baz) + }, + }; + console.log("Testing", foo) + ]], + }) + end) + end) + it("calls highlight.highlight_add_to_batch for each target", function() neolog.setup()