From e88dd08c82fb623846f1b78c4616178f96928cfc Mon Sep 17 00:00:00 2001 From: Kristijan Husak Date: Mon, 6 Nov 2023 21:37:44 +0100 Subject: [PATCH] refactor(markup): Split parsing markup by type --- lua/orgmode/colors/markup_highlighter.lua | 358 ++++++++++++---------- 1 file changed, 201 insertions(+), 157 deletions(-) diff --git a/lua/orgmode/colors/markup_highlighter.lua b/lua/orgmode/colors/markup_highlighter.lua index 6a9530709..3c590770c 100644 --- a/lua/orgmode/colors/markup_highlighter.lua +++ b/lua/orgmode/colors/markup_highlighter.lua @@ -1,5 +1,6 @@ local config = require('orgmode.config') local ts_utils = require('nvim-treesitter.ts_utils') +---@type Query local query = nil local valid_pre_marker_chars = { ' ', '(', '-', "'", '"', '{', '*', '/', '_', '+' } @@ -75,42 +76,41 @@ local markers = { }, } ----@param node userdata +---@param node TSNode ---@param source number ---@param offset_col_start? number ---@param offset_col_end? number ---@return string local function get_node_text(node, source, offset_col_start, offset_col_end) - local start_row, start_col = node:start() - local end_row, end_col = node:end_() - start_col = start_col + (offset_col_start or 0) - end_col = end_col + (offset_col_end or 0) - - local lines - local eof_row = vim.api.nvim_buf_line_count(source) - if start_row >= eof_row then - return '' - end + local range = { node:range() } + return vim.treesitter.get_node_text(node, source, { + metadata = { + range = { + range[1], + math.max(0, range[2] + (offset_col_start or 0)), + range[3], + math.max(0, range[4] + (offset_col_end or 0)), + }, + }, + }) +end - if end_col == 0 then - lines = vim.api.nvim_buf_get_lines(source, start_row, end_row, true) - end_col = -1 - else - lines = vim.api.nvim_buf_get_lines(source, start_row, end_row + 1, true) +---@param start_node TSNode +---@param end_node TSNode +---@return boolean +local function validate(start_node, end_node) + if not start_node or not end_node then + return false end - if #lines > 0 then - if #lines == 1 then - lines[1] = string.sub(lines[1], start_col + 1, end_col) - else - lines[1] = string.sub(lines[1], start_col + 1) - lines[#lines] = string.sub(lines[#lines], 1, end_col) - end - end + local start_line = start_node:range() + local end_line = end_node:range() - return table.concat(lines, '\n') + return start_line == end_line end +---@param bufnr number +---@return TSNode|nil local get_tree = ts_utils.memoize_by_buf_tick(function(bufnr) local tree = vim.treesitter.get_parser(bufnr, 'org'):parse() if not tree or not #tree then @@ -123,14 +123,9 @@ local function is_valid_markup_range(match, _, source, predicates) local start_node = match[predicates[2]] local end_node = match[predicates[3]] - if not start_node or not end_node then - return - end - - local start_line = start_node:range() - local end_line = end_node:range() + local is_valid = validate(start_node, end_node) - if start_line ~= end_line then + if not is_valid then return false end @@ -151,14 +146,9 @@ local function is_valid_hyperlink_range(match, _, source, predicates) local start_node = match[predicates[2]] local end_node = match[predicates[3]] - if not start_node or not end_node then - return - end - - local start_line = start_node:range() - local end_line = start_node:range() + local is_valid = validate(start_node, end_node) - if start_line ~= end_line then + if not is_valid then return false end @@ -216,86 +206,63 @@ local function load_deps() if query then return end - query = vim.treesitter.query.get('org', 'markup') + + query = vim.treesitter.query.get('org', 'markup') --[[@as Query]] vim.treesitter.query.add_predicate('org-is-valid-markup-range?', is_valid_markup_range) vim.treesitter.query.add_predicate('org-is-valid-hyperlink-range?', is_valid_hyperlink_range) vim.treesitter.query.add_predicate('org-is-valid-latex-range?', is_valid_latex_range) end ----@param bufnr number ----@param line_index number ----@return table -local get_matches = ts_utils.memoize_by_buf_tick(function(bufnr, line_index, root) - local ranges = {} - local taken_locations = {} - - for _, match, _ in query:iter_matches(root, bufnr, line_index, line_index + 1) do - for _, node in pairs(match) do - local char = node:type() - -- saves unnecessary parsing, since \\ is not used below - if char ~= '\\' then - local range = ts_utils.node_to_lsp_range(node) - local linenr = tostring(range.start.line) - taken_locations[linenr] = taken_locations[linenr] or {} - if not taken_locations[linenr][range.start.character] then - table.insert(ranges, { - type = char, - range = range, - node = node, - }) - taken_locations[linenr][range.start.character] = true - end - end - end - end - - table.sort(ranges, function(a, b) +local function sort_entries(entries) + return table.sort(entries, function(a, b) if a.range.start.line == b.range.start.line then return a.range.start.character < b.range.start.character end return a.range.start.line < b.range.start.line end) +end + +local function get_links(entries) + if not entries then + return {} + end + + sort_entries(entries) local seek = {} - local seek_link = {} local result = {} - local link_result = {} - local latex_result = {} - local nested = {} - local can_nest = true - - local type_map = { - ['('] = '\\(', - [')'] = '\\(', - ['}'] = '\\{', - } + for _, item in ipairs(entries) do + if item.type == '[' then + seek = item + end - for _, item in ipairs(ranges) do - if item.type == '(' then - item.range.start.character = item.range.start.character - 1 - elseif item.type == 'str' then - item.range.start.character = item.range.start.character - 1 - local char = get_node_text(item.node, bufnr, 0, 1):sub(-1) - if char == '{' then - item.type = '\\{' - else - item.type = '\\s' - end + if item.type == ']' and seek then + table.insert(result, { + from = seek.range, + to = item.range, + }) + seek = nil end + end - item.type = type_map[item.type] or item.type + return result +end + +local function generate_results(entries, self_contained_check_fn) + local seek = {} + local result = {} + local nested = {} + local can_nest = true + sort_entries(entries) + + for _, item in ipairs(entries) do if markers[item.type] then if seek[item.type] then local from = seek[item.type] if nested[#nested] == nil or nested[#nested] == from.type then - local target_result = result - if markers[item.type].type == 'latex' then - target_result = latex_result - end - - table.insert(target_result, { + table.insert(result, { type = item.type, from = from.range, to = item.range, @@ -316,8 +283,7 @@ local get_matches = ts_utils.memoize_by_buf_tick(function(bufnr, line_index, roo end end elseif can_nest then - -- escaped strings have no pairs, their markup info is self-contained - if item.type == '\\s' then + if self_contained_check_fn and self_contained_check_fn(item) then table.insert(result, { type = item.type, from = item.range, @@ -330,24 +296,92 @@ local get_matches = ts_utils.memoize_by_buf_tick(function(bufnr, line_index, roo end end end + end - if item.type == '[' then - seek_link = item + return result +end + +local function get_markup(entries) + if not entries then + return {} + end + + return generate_results(entries) +end + +local function get_latex(entries, bufnr) + if not entries then + return {} + end + + local type_map = { + ['('] = '\\(', + [')'] = '\\(', + ['}'] = '\\{', + } + + for _, item in ipairs(entries) do + if item.type == '(' then + item.range.start.character = item.range.start.character - 1 + elseif item.type == 'str' then + item.range.start.character = item.range.start.character - 1 + local char = get_node_text(item.node, bufnr, 0, 1):sub(-1) + if char == '{' then + item.type = '\\{' + else + item.type = '\\s' + end end - if item.type == ']' and seek_link then - table.insert(link_result, { - from = seek_link.range, - to = item.range, - }) - seek_link = nil + item.type = type_map[item.type] or item.type + end + + return generate_results(entries, function(item) + return item.type == '\\s' + end) +end + +---@param bufnr number +---@param line_index number +---@return table +local get_matches = ts_utils.memoize_by_buf_tick(function(bufnr, line_index, root) + local ranges = {} + local taken_locations = {} + + for _, match, _ in query:iter_matches(root, bufnr, line_index, line_index + 1) do + for _, node in pairs(match) do + local char = node:type() + local marker = markers[char] + local type = nil + if marker then + type = marker.type + elseif char == '[' or char == ']' then + type = 'link' + elseif char ~= '\\' then + type = 'latex' + end + + if type then + ranges[type] = ranges[type] or {} + local range = ts_utils.node_to_lsp_range(node) + local linenr = tostring(range.start.line) + taken_locations[linenr] = taken_locations[linenr] or {} + if not taken_locations[linenr][range.start.character] then + table.insert(ranges[type], { + type = char, + range = range, + node = node, + }) + taken_locations[linenr][range.start.character] = true + end + end end end return { - ranges = result, - link_ranges = link_result, - latex_ranges = latex_result, + markup_ranges = get_markup(ranges.text), + link_ranges = get_links(ranges.link), + latex_ranges = get_latex(ranges.latex, bufnr), } end, { key = function(bufnr, line_index) @@ -355,105 +389,115 @@ end, { end, }) -local function apply(namespace, bufnr, line_index) - bufnr = bufnr or 0 - local root = get_tree(bufnr) - if not root then - return - end - - local result = get_matches(bufnr, line_index, root) +local function highlight_markup(namespace, bufnr, entries) local hide_markers = config.org_hide_emphasis_markers - - for _, range in ipairs(result.ranges) do + for _, entry in ipairs(entries) do local hl_offset = 0 - if markers[range.type].delimiter_hl then + if markers[entry.type].delimiter_hl then hl_offset = 1 -- Leading delimiter - vim.api.nvim_buf_set_extmark(bufnr, namespace, range.from.start.line, range.from.start.character, { + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.from.start.line, entry.from.start.character, { ephemeral = true, - end_col = range.from.start.character + hl_offset, - hl_group = markers[range.type].hl_name .. '_delimiter', - spell = markers[range.type].spell, - priority = 110 + range.from.start.character, + end_col = entry.from.start.character + hl_offset, + hl_group = markers[entry.type].hl_name .. '_delimiter', + spell = markers[entry.type].spell, + priority = 110 + entry.from.start.character, }) -- Closing delimiter - vim.api.nvim_buf_set_extmark(bufnr, namespace, range.from.start.line, range.to['end'].character - hl_offset, { + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.from.start.line, entry.to['end'].character - hl_offset, { ephemeral = true, - end_col = range.to['end'].character, - hl_group = markers[range.type].hl_name .. '_delimiter', - spell = markers[range.type].spell, - priority = 110 + range.from.start.character, + end_col = entry.to['end'].character, + hl_group = markers[entry.type].hl_name .. '_delimiter', + spell = markers[entry.type].spell, + priority = 110 + entry.from.start.character, }) end -- Main body highlight - vim.api.nvim_buf_set_extmark(bufnr, namespace, range.from.start.line, range.from.start.character + hl_offset, { + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.from.start.line, entry.from.start.character + hl_offset, { ephemeral = true, - end_col = range.to['end'].character - hl_offset, - hl_group = markers[range.type].hl_name, - spell = markers[range.type].spell, - priority = 110 + range.from.start.character, + end_col = entry.to['end'].character - hl_offset, + hl_group = markers[entry.type].hl_name, + spell = markers[entry.type].spell, + priority = 110 + entry.from.start.character, }) if hide_markers then - vim.api.nvim_buf_set_extmark(bufnr, namespace, range.from.start.line, range.from.start.character, { - end_col = range.from['end'].character, + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.from.start.line, entry.from.start.character, { + end_col = entry.from['end'].character, ephemeral = true, conceal = '', }) - vim.api.nvim_buf_set_extmark(bufnr, namespace, range.to.start.line, range.to.start.character, { - end_col = range.to['end'].character, + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.to.start.line, entry.to.start.character, { + end_col = entry.to['end'].character, ephemeral = true, conceal = '', }) end end +end - for _, link_range in ipairs(result.link_ranges) do - local line = vim.api.nvim_buf_get_lines(bufnr, link_range.from.start.line, link_range.from.start.line + 1, false)[1] - local link = line:sub(link_range.from.start.character + 1, link_range.to['end'].character) +local function highlight_links(namespace, bufnr, entries) + for _, entry in ipairs(entries) do + local line = vim.api.nvim_buf_get_lines(bufnr, entry.from.start.line, entry.from.start.line + 1, false)[1] + local link = line:sub(entry.from.start.character + 1, entry.to['end'].character) local alias = link:find('%]%[') or 1 local link_end = link:find('%]%[') or (link:len() - 1) - vim.api.nvim_buf_set_extmark(bufnr, namespace, link_range.from.start.line, link_range.from.start.character, { + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.from.start.line, entry.from.start.character, { ephemeral = true, - end_col = link_range.to['end'].character, + end_col = entry.to['end'].character, hl_group = 'org_hyperlink', priority = 110, }) - vim.api.nvim_buf_set_extmark(bufnr, namespace, link_range.from.start.line, link_range.from.start.character, { + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.from.start.line, entry.from.start.character, { ephemeral = true, - end_col = link_range.from.start.character + 1 + alias, + end_col = entry.from.start.character + 1 + alias, conceal = '', }) - vim.api.nvim_buf_set_extmark(bufnr, namespace, link_range.from.start.line, link_range.from.start.character + 2, { + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.from.start.line, entry.from.start.character + 2, { ephemeral = true, - end_col = link_range.from.start.character - 1 + link_end, + end_col = entry.from.start.character - 1 + link_end, spell = false, }) - vim.api.nvim_buf_set_extmark(bufnr, namespace, link_range.from.start.line, link_range.to['end'].character - 2, { + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.from.start.line, entry.to['end'].character - 2, { ephemeral = true, - end_col = link_range.to['end'].character, + end_col = entry.to['end'].character, conceal = '', }) end +end - for _, latex_range in ipairs(result.latex_ranges) do - vim.api.nvim_buf_set_extmark(bufnr, namespace, latex_range.from.start.line, latex_range.from.start.character, { +local function highlight_latex(namespace, bufnr, entries) + for _, entry in ipairs(entries) do + vim.api.nvim_buf_set_extmark(bufnr, namespace, entry.from.start.line, entry.from.start.character, { ephemeral = true, - end_col = latex_range.to['end'].character, - hl_group = markers[latex_range.type].hl_name, - spell = markers[latex_range.type].spell, - priority = 110 + latex_range.from.start.character, + end_col = entry.to['end'].character, + hl_group = markers[entry.type].hl_name, + spell = markers[entry.type].spell, + priority = 110 + entry.from.start.character, }) end end +local function apply(namespace, bufnr, line_index) + bufnr = bufnr or 0 + local root = get_tree(bufnr) + if not root then + return + end + + local result = get_matches(bufnr, line_index, root) + + highlight_markup(namespace, bufnr, result.markup_ranges) + highlight_links(namespace, bufnr, result.link_ranges) + highlight_latex(namespace, bufnr, result.latex_ranges) +end + local function setup() for _, marker in pairs(markers) do vim.cmd(string.format(marker.hl_cmd, marker.hl_name))