From bdcd74d9ee90f57d278edb6d1aa784e2db039c44 Mon Sep 17 00:00:00 2001 From: Anshuman Date: Thu, 24 Aug 2023 18:41:52 +0800 Subject: [PATCH] feat: nested treesj - use `vim.v.count1` to select parents based on nesting depth - use `flash.nvim` to select parents using labels - dot repeat remembers the nesting depth (configurable) --- lua/treesj/format.lua | 35 +++++++++++- lua/treesj/init.lua | 25 +++++++-- lua/treesj/notify.lua | 1 + lua/treesj/search.lua | 40 ++++++++++++++ lua/treesj/selectors/count.lua | 8 +++ lua/treesj/selectors/flash.lua | 97 ++++++++++++++++++++++++++++++++++ lua/treesj/settings.lua | 2 + 7 files changed, 202 insertions(+), 6 deletions(-) create mode 100644 lua/treesj/selectors/count.lua create mode 100644 lua/treesj/selectors/flash.lua diff --git a/lua/treesj/format.lua b/lua/treesj/format.lua index 55746a8..7761194 100644 --- a/lua/treesj/format.lua +++ b/lua/treesj/format.lua @@ -45,7 +45,7 @@ function M.get_node_at_cursor(root_lang_tree) ) end -function M._format(mode, override) +function M._format(mode, override, selector) -- Tree reparsing is required, otherwise the tree may not be updated -- and each node will be processed only once (until -- the tree is updated). See issue #118 @@ -67,7 +67,7 @@ function M._format(mode, override) -- If the node is marked as "disabled", continue searching from its parent. while true do - found, tsn_data = pcall(search.get_configured_node, start_node) + found, tsn_data = pcall(selector or search.get_configured_node, start_node) if not found then notify.warn(tsn_data) return @@ -161,4 +161,35 @@ function M._format(mode, override) pcall(vim.api.nvim_win_set_cursor, 0, new_cursor) end +M.last_selected = nil +function M._nested(selector, mode, preset) + local nodes + M._format(mode, preset, function(start_node) + if not nodes then + nodes = search.get_configured_nodes(start_node) + end + local selected + if settings.remember_selected and M.last_selected then + -- TODO: basic indexing is not ideal if the sticky cursor doesn't work? + selected = nodes[M.last_selected] + else + if type(selector) == 'string' then + selector = require('treesj.selectors.' .. selector).selector + end + local sel, index = selector(nodes) + if sel then + M.last_selected = index + selected = sel + end + end + if selected then + return selected + else + -- TODO: change the error messages + error(msg.no_chosen_node, 0) + return + end + end) +end + return M diff --git a/lua/treesj/init.lua b/lua/treesj/init.lua index 18f2069..6d92b8e 100644 --- a/lua/treesj/init.lua +++ b/lua/treesj/init.lua @@ -12,16 +12,14 @@ local function repeatable(fn) if settings.settings.dot_repeat then M.__repeat = fn vim.opt.operatorfunc = "v:lua.require'treesj'.__repeat" - vim.api.nvim_feedkeys('g@l', 'n', true) + vim.api.nvim_feedkeys(vim.v.count1 .. 'g@l', 'n', true) else fn() end end M.format = function(mode, preset) - repeatable(function() - require('treesj.format')._format(mode, preset) - end) + M.nested_format('count', mode, preset) end M.toggle = function(preset) @@ -36,4 +34,23 @@ M.split = function(preset) M.format('split', preset) end +M.nested_format = function(selector, mode, preset) + require('treesj.format').last_selected = nil + repeatable(function() + require('treesj.format')._nested(selector, mode, preset) + end) +end + +M.nested_toggle = function(selector, preset) + M.nested_format(selector, nil, preset) +end + +M.nested_join = function(selector, preset) + M.nested_format(selector, 'join', preset) +end + +M.nested_split = function(selector, preset) + M.nested_format(selector, 'split', preset) +end + return M diff --git a/lua/treesj/notify.lua b/lua/treesj/notify.lua index 7940e29..a9155b0 100644 --- a/lua/treesj/notify.lua +++ b/lua/treesj/notify.lua @@ -7,6 +7,7 @@ local M = {} M.msg = { no_detect_node = 'No detected node at cursor', + no_chosen_node = 'Node choice aborted', no_configured_lang = 'Language "%s" is not configured', contains_error = 'The node "%s" or its descendants contain a syntax error and cannot be %s', no_configured_node = 'Node "%s" for lang "%s" is not configured', diff --git a/lua/treesj/search.lua b/lua/treesj/search.lua index af64910..3921adb 100644 --- a/lua/treesj/search.lua +++ b/lua/treesj/search.lua @@ -173,6 +173,46 @@ function M.get_configured_node(node) return data end +---Return the all configured nodes +---@param node TSNode|nil TSNode instance +---@return table +function M.get_configured_nodes(node) + if not node then + error(msg.node_not_received, 0) + end + + local lang = get_node_lang(node) + if not langs[lang] then + error(msg.no_configured_lang:format(lang), 0) + end + local start_node_type = node:type() + + local nodes = {} + local done = {} + while node do + local data = search_node(node, lang) + + if not data or not data.tsnode then + error(msg.no_configured_node:format(start_node_type, lang), 0) + break + end + + local id = table.concat({ data.tsnode:range() }, '-') + if done[id] then + break + else + done[id] = true + end + + nodes[#nodes + 1] = data + node = data.tsnode:parent() + end + + if #nodes == 0 then + error(msg.no_configured_node:format(start_node_type, lang), 0) + end + return nodes +end ---Return the preset for current node if it no contains field 'target_nodes' ---@param tsn_type string TSNode type diff --git a/lua/treesj/selectors/count.lua b/lua/treesj/selectors/count.lua new file mode 100644 index 0000000..ffa27a7 --- /dev/null +++ b/lua/treesj/selectors/count.lua @@ -0,0 +1,8 @@ +local M = {} +local meta_M = {} +M = setmetatable(M, meta_M) +function M.selector(nodes) + local c = vim.v.count1 + return nodes[c], c +end +return M diff --git a/lua/treesj/selectors/flash.lua b/lua/treesj/selectors/flash.lua new file mode 100644 index 0000000..17575d8 --- /dev/null +++ b/lua/treesj/selectors/flash.lua @@ -0,0 +1,97 @@ +local M = {} +local flash = require('flash') +local Pos = require('flash.search.pos') + +local opts = { mode = 'treesj' } +function M.setup(config) + opts = config +end + +function M.selector(nodes) + local currwin = vim.api.nvim_get_current_win() + local buf = vim.api.nvim_get_current_buf() + local line_count = vim.api.nvim_buf_line_count(buf) + local state = flash.jump(vim.tbl_deep_extend('force', { + matcher = function(win, _) + if win ~= currwin then + return {} + end + local ret = {} + local done = {} + -- https://github.com/folke/flash.nvim/blob/967117690bd677cb7b6a87f0bc0077d2c0be3a27/lua/flash/plugins/treesitter.lua#L52 + for i, node in ipairs(nodes) do + local tsn = node.tsnode + local range = { tsn:range() } + local match = { + win = win, + node = node, + pos = { range[1] + 1, range[2] }, + end_pos = { range[3] + 1, range[4] - 1 }, + index = i, + } + + -- If the match is at the end of the buffer, + -- then move it to the last character of the last line. + if match.end_pos[1] > line_count then + match.end_pos[1] = line_count + match.end_pos[2] = #vim.api.nvim_buf_get_lines( + buf, + match.end_pos[1] - 1, + match.end_pos[1], + false + )[1] + elseif match.end_pos[2] == -1 then + -- If the end points to the start of the next line, move it to the + -- end of the previous line. + -- Otherwise operations include the first character of the next line + local line = vim.api.nvim_buf_get_lines( + buf, + match.end_pos[1] - 2, + match.end_pos[1] - 1, + false + )[1] + match.end_pos[1] = match.end_pos[1] - 1 + match.end_pos[2] = #line + end + local id = + table.concat(vim.tbl_flatten({ match.pos, match.end_pos }), '.') + if not done[id] then + done[id] = true + ret[#ret + 1] = match + end + end + for m, match in ipairs(ret) do + match.pos = Pos(match.pos) + match.end_pos = Pos(match.end_pos) + match.depth = #ret - m + end + return ret + end, + action = function(match, state) + state.final_match = match + end, + search = { + multi_window = false, + wrap = true, + incremental = false, + max_length = 0, + }, + label = { + before = true, + after = true, + }, + highlight = { + matches = false, + }, + actions = { + -- TODO: incremental preview/operations + }, + jump = { autojump = true }, + }, opts or {})) + local m = state.final_match + if m then + return m.node, m.index + end +end + +return M diff --git a/lua/treesj/settings.lua b/lua/treesj/settings.lua index 0dca9de..a39a85c 100644 --- a/lua/treesj/settings.lua +++ b/lua/treesj/settings.lua @@ -25,6 +25,8 @@ local DEFAULT_SETTINGS = { dot_repeat = true, ---@type nil|function Callback for treesj error handler. func (err_text, level, ...) on_error = nil, + ---@type boolean Whether dot repeat on nested operations should remember the selected level of nesting + remember_selected = true, } local commands = {